diff --git a/IFChatPromptNode.py b/IFChatPromptNode.py index 653b779..c616909 100644 --- a/IFChatPromptNode.py +++ b/IFChatPromptNode.py @@ -1,1114 +1,1114 @@ -# IFChatPromptNode.py -import os -import sys -import json -import torch -import shutil -import base64 -import platform -import importlib -import subprocess -import numpy as np -import folder_paths -from PIL import Image -import yaml -from io import BytesIO -import asyncio -from typing import List, Union, Dict, Any, Tuple, Optional -from .agent_tool import AgentTool -from .send_request import send_request -#from .transformers_api import TransformersModelManager -import tempfile -import threading -from aiohttp import web -from .graphRAG_module import GraphRAGapp -from .colpaliRAG_module import colpaliRAGapp -from .superflorence import FlorenceModule -from .utils import get_api_key, get_models, validate_models, clean_text, process_mask, load_placeholder_image, process_images_for_comfy -#from byaldi import RAGMultiModalModel -# Set up logging -import logging -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) -# Add the ComfyUI directory to the Python path -comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) -sys.path.insert(0, comfy_path) - -ifchat_prompt_node = None - -try: - from server import PromptServer - - @PromptServer.instance.routes.post("/IF_ChatPrompt/get_llm_models") - async def get_llm_models_endpoint(request): - data = await request.json() - llm_provider = data.get("llm_provider") - engine = llm_provider - base_ip = data.get("base_ip") - port = data.get("port") - external_api_key = data.get("external_api_key") - - logger.debug(f"Received request for LLM models. Provider: {llm_provider}, External API key provided: {bool(external_api_key)}") - - if external_api_key: - api_key = external_api_key - logger.debug("Using provided external LLM API key") - else: - api_key_name = f"{llm_provider.upper()}_API_KEY" - try: - api_key = get_api_key(api_key_name, engine) - logger.debug("Using API key from environment or .env file") - except ValueError: - logger.warning(f"No API key found for {llm_provider}. Attempting to proceed without an API key.") - api_key = None - - models = get_models(engine, base_ip, port, api_key) - logger.debug(f"Fetched {len(models)} models for {llm_provider}") - return web.json_response(models) - - @PromptServer.instance.routes.post("/IF_ChatPrompt/get_embedding_models") - async def get_embedding_models_endpoint(request): - data = await request.json() - embedding_provider = data.get("embedding_provider") - engine = embedding_provider - base_ip = data.get("base_ip") - port = data.get("port") - external_api_key = data.get("external_api_key") - - logger.debug(f"Received request for LLM models. Provider: {embedding_provider}, External API key provided: {bool(external_api_key)}") - - if external_api_key: - api_key = external_api_key - logger.debug("Using provided external LLM API key") - else: - api_key_name = f"{embedding_provider.upper()}_API_KEY" - try: - api_key = get_api_key(api_key_name, engine) - logger.debug("Using API key from environment or .env file") - except ValueError: - logger.warning(f"No API key found for {embedding_provider}. Attempting to proceed without an API key.") - api_key = None - - models = get_models(engine, base_ip, port, api_key) - logger.debug(f"Fetched {len(models)} models for {embedding_provider}") - return web.json_response(models) - - @PromptServer.instance.routes.post("/IF_ChatPrompt/upload_file") - async def upload_file_route(request): - try: - reader = await request.multipart() - - rag_folder_name = None - file_content = None - filename = None - - # Process all parts of the multipart request - while True: - part = await reader.next() - if part is None: - break - if part.name == "rag_root_dir": - rag_folder_name = await part.text() - elif part.filename: - filename = part.filename - file_content = await part.read() - - if not filename or not file_content or not rag_folder_name: - return web.json_response({"status": "error", "message": "Missing file, filename, or RAG folder name"}) - - node = IFChatPrompt() - input_dir = os.path.join(node.rag_dir, rag_folder_name, "input") - - if not os.path.exists(input_dir): - os.makedirs(input_dir, exist_ok=True) - - file_path = os.path.join(input_dir, filename) - - with open(file_path, 'wb') as f: - f.write(file_content) - - logger.info(f"File uploaded to: {file_path}") - return web.json_response({"status": "success", "message": f"File uploaded to: {file_path}"}) - - except Exception as e: - logger.error(f"Error in upload_file_route: {str(e)}") - return web.json_response({"status": "error", "message": f"Error uploading file: {str(e)}"}) - - @PromptServer.instance.routes.post("/IF_ChatPrompt/setup_and_initialize") - async def setup_and_initialize(request): - global ifchat_prompt_node - - data = await request.json() - folder_name = data.get('folder_name', 'rag_data') - - if ifchat_prompt_node is None: - ifchat_prompt_node = IFChatPrompt() - - init_result = await ifchat_prompt_node.graphrag_app.setup_and_initialize_folder(folder_name, data) - - ifchat_prompt_node.rag_folder_name = folder_name - ifchat_prompt_node.colpali_app.set_rag_root_dir(folder_name) - - return web.json_response(init_result) - - @PromptServer.instance.routes.post("/IF_ChatPrompt/run_indexer") - async def run_indexer_endpoint(request): - try: - data = await request.json() - logger.debug(f"Received indexing request with data: {data}") - - global ifchat_prompt_node # Access the global instance - - # Set the rag_root_dir in both modules using the global instance - ifchat_prompt_node.graphrag_app.set_rag_root_dir(data.get('rag_folder_name')) - ifchat_prompt_node.colpali_app.set_rag_root_dir(data.get('rag_folder_name')) - - query_type = data.get('mode_type') - logger.debug(f"Query type: {query_type}") - - logger.debug(f"Starting indexing process for query type: {query_type}") - - # Initialize the colpali_model before calling insert, using the global instance - if query_type == 'colpali' or query_type == 'colqwen2' or query_type == 'colpali-v1.2': - _ = ifchat_prompt_node.colpali_app.get_colpali_model(query_type) # This will load or retrieve the cached model - result = await ifchat_prompt_node.colpali_app.insert() - else: - result = await ifchat_prompt_node.graphrag_app.insert() - - logger.debug(f"Indexing process completed with result: {result}") - - if result: - return web.json_response({"status": "success", "message": f"Indexing complete for {query_type}"}) - else: - return web.json_response({"status": "error", "message": "Indexing failed. Check server logs."}, status=500) - - except Exception as e: - logger.error(f"Error in run_indexer_endpoint: {str(e)}") - return web.json_response({"status": "error", "message": f"Error during indexing: {str(e)}"}, status=500) - - @PromptServer.instance.routes.post("/IF_ChatPrompt/process_chat") - async def process_chat_endpoint(request): - try: - data = await request.json() - - # Set default values for required arguments if not provided - defaults = { - "prompt": "", - "assistant": "Cortana", # Default assistant - "neg_prompt": "Default", # Default negative prompt - "embellish_prompt": "Default", # Default embellishment - "style_prompt": "Default", # Default style - "llm_provider": "ollama", - "llm_model": "", - "base_ip": "localhost", - "port": "11434", - "embedding_model": "", - "embedding_provider": "sentence_transformers" - } - - # Update data with defaults for missing keys - for key, default_value in defaults.items(): - if key not in data: - data[key] = default_value - - global ifchat_prompt_node - result = await ifchat_prompt_node.process_chat(**data) - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error in process_chat_endpoint: {str(e)}") - return web.json_response({ - "status": "error", - "message": f"Error processing chat: {str(e)}", - "Question": data.get("prompt", ""), - "Response": f"Error: {str(e)}", - "Negative": "", - "Tool_Output": None, - "Retrieved_Image": None, - "Mask": None - }, status=500) - - @PromptServer.instance.routes.post("/IF_ChatPrompt/load_index") - async def load_index_route(request): - try: - data = await request.json() - index_name = data.get('rag_folder_name') - query_type = data.get('query_type') - - if not index_name: - logger.error("No index name provided in the request.") - return web.json_response({ - "status": "error", - "message": "No index name provided" - }) - - # Check if index exists in .byaldi directory - byaldi_index_path = os.path.join(".byaldi", index_name) - if not os.path.exists(byaldi_index_path): - logger.error(f"Index not found in .byaldi: {byaldi_index_path}") - return web.json_response({ - "status": "error", - "message": f"Index {index_name} does not exist" - }) - - try: - global ifchat_prompt_node - if ifchat_prompt_node is None: - logger.debug("Initializing IFChatPrompt instance.") - ifchat_prompt_node = IFChatPrompt() - - if query_type in ['colpali', 'colqwen2', 'colpali-v1.2']: - logger.debug(f"Loading model for query type: {query_type}") - - # Clear any existing cached index - ifchat_prompt_node.colpali_app.cleanup_index() - - # First get the base model - colpali_model = ifchat_prompt_node.colpali_app.get_colpali_model(query_type) - - if colpali_model: - # Load and cache the new index - model = await ifchat_prompt_node.colpali_app._prepare_model(query_type, index_name) - if not model: - raise ValueError("Failed to load and cache index") - - # Set the RAG root directory - ifchat_prompt_node.colpali_app.set_rag_root_dir(index_name) - - logger.info(f"Successfully loaded and cached index: {index_name}") - return web.json_response({ - "status": "success", - "message": f"Successfully loaded index: {index_name}", - "rag_root_dir": index_name - }) - else: - logger.error("Failed to initialize ColPali model.") - raise ValueError("Failed to initialize ColPali model") - - else: - logger.error(f"Unsupported query type: {query_type}") - return web.json_response({ - "status": "error", - "message": f"Query type {query_type} not supported for loading indexes" - }) - - except Exception as e: - logger.error(f"Error loading index {index_name}: {str(e)}") - return web.json_response({ - "status": "error", - "message": f"Error loading index: {str(e)}" - }) - - except Exception as e: - logger.error(f"Error in load_index_route: {str(e)}") - return web.json_response({ - "status": "error", - "message": f"Error processing request: {str(e)}" - }) - - # Add this with the other routes - @PromptServer.instance.routes.post("/IF_ChatPrompt/delete_index") - async def delete_index_route(request): - try: - data = await request.json() - index_name = data.get('rag_folder_name') - - if not index_name: - return web.json_response({ - "status": "error", - "message": "No index name provided" - }) - - # Path to the index - index_path = os.path.join(".byaldi", index_name) - - if not os.path.exists(index_path): - return web.json_response({ - "status": "error", - "message": f"Index {index_name} does not exist" - }) - - # Delete the index directory - try: - shutil.rmtree(index_path) - logger.info(f"Successfully deleted index: {index_name}") - return web.json_response({ - "status": "success", - "message": f"Successfully deleted index: {index_name}" - }) - except Exception as e: - logger.error(f"Error deleting index {index_name}: {str(e)}") - return web.json_response({ - "status": "error", - "message": f"Error deleting index: {str(e)}" - }) - - except Exception as e: - logger.error(f"Error in delete_index_route: {str(e)}") - return web.json_response({ - "status": "error", - "message": f"Error processing request: {str(e)}" - }) - -except AttributeError: - print("PromptServer.instance not available. Skipping route decoration for IF_ChatPrompt.") - -class IFChatPrompt: - - def __init__(self): - self.base_ip = "localhost" - self.port = "11434" - self.llm_provider = "ollama" - self.embedding_provider = "sentence_transformers" - self.llm_model = "" - self.embedding_model = "" - self.assistant = "None" - self.random = False - - self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "rag") - self.presets_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "presets") - - self.stop_file = os.path.join(self.presets_dir, "stop_strings.json") - self.assistants_file = os.path.join(self.presets_dir, "assistants.json") - self.neg_prompts_file = os.path.join(self.presets_dir, "neg_prompts.json") - self.embellish_prompts_file = os.path.join(self.presets_dir, "embellishments.json") - self.style_prompts_file = os.path.join(self.presets_dir, "style_prompts.json") - self.tasks_file = os.path.join(self.presets_dir, "florence_prompts.json") - self.agents_dir = os.path.join(self.presets_dir, "agents") - - self.agent_tools = self.load_agent_tools() - self.stop_strings = self.load_presets(self.stop_file) - self.assistants = self.load_presets(self.assistants_file) - self.neg_prompts = self.load_presets(self.neg_prompts_file) - self.embellish_prompts = self.load_presets(self.embellish_prompts_file) - self.style_prompts = self.load_presets(self.style_prompts_file) - self.florence_prompts = self.load_presets(self.tasks_file) - - self.keep_alive = False - self.seed = 94687328150 - self.messages = [] - self.history_steps = 10 - self.external_api_key = "" - self.tool_input = "" - self.prime_directives = None - self.rag_folder_name = "rag_data" - self.graphrag_app = GraphRAGapp() - self.colpali_app = colpaliRAGapp() - self.fix_json = True - self.cached_colpali_model = None - self.florence_app = FlorenceModule() - self.florence_models = {} - self.query_type = "global" - self.enable_RAG = False - self.clear_history = False - self.mode = False - self.tool = "None" - self.preset = "Default" - self.precision = "fp16" - self.task = None - self.attention = "sdpa" - self.aspect_ratio = "16:9" - self.top_k_search = 3 - - self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "placeholder.png") - - if not os.path.exists(self.placeholder_image_path): - placeholder = Image.new('RGB', (512, 512), color=(73, 109, 137)) - os.makedirs(os.path.dirname(self.placeholder_image_path), exist_ok=True) - placeholder.save(self.placeholder_image_path) - - def load_presets(self, file_path): - with open(file_path, 'r') as f: - presets = json.load(f) - return presets - - def load_agent_tools(self): - os.makedirs(self.agents_dir, exist_ok=True) - agent_tools = {} - try: - for filename in os.listdir(self.agents_dir): - if filename.endswith('.json'): - full_path = os.path.join(self.agents_dir, filename) - with open(full_path, 'r') as f: - try: - data = json.load(f) - if 'output_type' not in data: - data['output_type'] = None - agent_tool = AgentTool(**data) - agent_tool.load() - if agent_tool._class_instance is not None: - if agent_tool.python_function: - agent_tools[agent_tool.name] = agent_tool - else: - print(f"Warning: Agent tool {agent_tool.name} in {filename} does not have a python_function defined.") - else: - print(f"Failed to create class instance for {filename}") - except json.JSONDecodeError: - print(f"Error: Invalid JSON in {filename}") - except Exception as e: - print(f"Error loading {filename}: {str(e)}") - return agent_tools - except Exception as e: - print(f"Warning: Error accessing agent tools directory: {str(e)}") - return {} - - async def process_chat( - self, - prompt, - llm_provider, - llm_model, - base_ip, - port, - assistant, - neg_prompt, - embellish_prompt, - style_prompt, - embedding_model, - embedding_provider, - external_api_key="", - temperature=0.7, - max_tokens=2048, - seed=0, - random=False, - history_steps=10, - keep_alive=False, - top_k=40, - top_p=0.2, - repeat_penalty=1.1, - stop_string=None, - images=None, - mode=True, - clear_history=False, - text_cleanup=True, - tool=None, - tool_input=None, - prime_directives=None, - enable_RAG=False, - query_type="global", - preset="Default", - rag_folder_name=None, - task=None, - fill_mask=False, - output_mask_select="", - precision="fp16", - attention="sdpa", - aspect_ratio="16:9", - top_k_search=3 - ): - - if external_api_key != "": - llm_api_key = external_api_key - else: - llm_api_key = get_api_key(f"{llm_provider.upper()}_API_KEY", llm_provider) - - print(f"LLM API key: {llm_api_key[:5]}...") - if prime_directives is not None: - system_message_str = prime_directives - else: - system_message = self.assistants.get(assistant, "") - system_message_str = json.dumps(system_message) - - # Validate LLM model - validate_models(llm_model, llm_provider, "LLM", base_ip, port, llm_api_key) - - # Validate embedding model - validate_models(embedding_model, embedding_provider, "embedding", base_ip, port, llm_api_key) - - # Handle history - if clear_history: - self.messages = [] - elif history_steps > 0: - self.messages = self.messages[-history_steps:] - - messages = self.messages - - # Handle stop - if stop_string is None or stop_string == "None": - stop_content = None - else: - stop_content = self.stop_strings.get(stop_string, None) - stop = stop_content - - if llm_provider not in ["ollama", "llamacpp", "vllm", "lmstudio", "gemeni"]: - if llm_provider == "kobold": - stop = stop_content + \ - ["\n\n\n\n\n"] if stop_content else ["\n\n\n\n\n"] - elif llm_provider == "mistral": - stop = stop_content + \ - ["\n\n"] if stop_content else ["\n\n"] - else: - stop = stop_content if stop_content else None - # Handle tools - try: - if tool and tool != "None": - selected_tool = self.agent_tools.get(tool) - if not selected_tool: - raise ValueError(f"Invalid agent tool selected: {tool}") - - # Prepare tool execution message - tool_message = f"Execute the {tool} tool with the following input: {prompt}" - system_prompt = json.dumps(selected_tool.system_prompt) - - # Send request to LLM for tool execution - generated_text =await send_request( - llm_provider=llm_provider, - base_ip=base_ip, - port=port, - images=images, - model=llm_model, - system_message=system_prompt, - user_message=tool_message, - messages=messages, - seed=seed, - temperature=temperature, - max_tokens=max_tokens, - random=random, - top_k=top_k, - top_p=top_p, - repeat_penalty=repeat_penalty, - stop=stop, - keep_alive=keep_alive, - llm_api_key=llm_api_key, - ) - # Parse the generated text for function calls - function_call = None - try: - response_data = json.loads(generated_text) - if 'function_call' in response_data: - function_call = response_data['function_call'] - generated_text = response_data['content'] - except json.JSONDecodeError: - pass # The response wasn't JSON, so it's just the generated text - - # Execute the tool with the LLM's response - tool_args = { - "input": prompt, - "llm_response": generated_text, - "function_call": function_call, - "omni_input": tool_input, - "name": selected_tool.name, - "description": selected_tool.description, - "system_prompt": selected_tool.system_prompt - } - tool_result = selected_tool.execute(tool_args) - - # Update messages - messages.append({"role": "user", "content": prompt}) - messages.append({ - "role": "assistant", - "content": json.dumps(tool_result) if isinstance(tool_result, dict) else str(tool_result) - }) - - # Process the tool output - if isinstance(tool_result, dict): - if "error" in tool_result: - generated_text = f"Error in {tool}: {tool_result['error']}" - tool_output = None - elif selected_tool.output_type and selected_tool.output_type in tool_result: - tool_output = tool_result[selected_tool.output_type] - generated_text = f"Agent {tool} executed successfully. Output generated." - else: - tool_output = tool_result - generated_text = str(tool_output) - else: - tool_output = tool_result - generated_text = str(tool_output) - - return { - "Question": prompt, - "Response": generated_text, - "Negative": self.neg_prompts.get(neg_prompt, ""), - "Tool_Output": tool_output, - "Retrieved_Image": None # No image retrieved in tool execution - } - else: - response = await self.generate_response( - enable_RAG, - query_type, - prompt, - preset, - llm_provider, - base_ip, - port, - images, - llm_model, - system_message_str, - messages, - temperature, - max_tokens, - random, - top_k, - top_p, - repeat_penalty, - stop, - seed, - keep_alive, - llm_api_key, - task, - fill_mask, - output_mask_select, - precision, - attention - ) - - generated_text = response.get("Response") - selected_neg_prompt_name = neg_prompt - omni = response.get("Tool_Output") - retrieved_image = response.get("Retrieved_Image") - retrieved_mask = response.get("Mask") - - - # Update messages - messages.append({"role": "user", "content": prompt}) - messages.append({"role": "assistant", "content": generated_text}) - - text_result = str(generated_text).strip() - - if mode: - embellish_content = self.embellish_prompts.get(embellish_prompt, "").strip() - style_content = self.style_prompts.get(style_prompt, "").strip() - - lines = [line.strip() for line in text_result.split('\n') if line.strip()] - combined_prompts = [] - - for line in lines: - if text_cleanup: - line = clean_text(line) - formatted_line = f"{embellish_content} {line} {style_content}".strip() - combined_prompts.append(formatted_line) - - combined_prompt = "\n".join(formatted_line for formatted_line in combined_prompts) - # Handle negative prompts - if selected_neg_prompt_name == "AI_Fill": - try: - neg_system_message = self.assistants.get("NegativePromptEngineer") - if not neg_system_message: - logger.error("NegativePromptEngineer not found in assistants configuration") - negative_prompt = "Error: NegativePromptEngineer not configured" - else: - user_message = f"Generate negative prompts for the following prompt:\n{text_result}" - - system_message_str = json.dumps(neg_system_message) - - logger.info(f"Requesting negative prompts for prompt: {text_result[:100]}...") - - neg_response = await send_request( - llm_provider=llm_provider, - base_ip=base_ip, - port=port, - images=None, - llm_model=llm_model, - system_message=system_message_str, - user_message=user_message, - messages=[], # Fresh context for negative generation - seed=seed, - temperature=temperature, - max_tokens=max_tokens, - random=random, - top_k=top_k, - top_p=top_p, - repeat_penalty=repeat_penalty, - stop=stop, - keep_alive=keep_alive, - llm_api_key=llm_api_key - ) - - logger.debug(f"Received negative prompt response: {neg_response}") - - if neg_response: - negative_lines = [] - for line in neg_response.split('\n'): - line = line.strip() - if line: - negative_lines.append(line) - - while len(negative_lines) < len(lines): - negative_lines.append(negative_lines[-1] if negative_lines else "") - negative_lines = negative_lines[:len(lines)] - - negative_prompt = "\n".join(negative_lines) - else: - negative_prompt = "Error: Empty response from LLM" - except Exception as e: - logger.error(f"Error generating negative prompts: {str(e)}", exc_info=True) - negative_prompt = f"Error generating negative prompts: {str(e)}" - - elif neg_prompt != "None": - neg_content = self.neg_prompts.get(neg_prompt, "").strip() - negative_lines = [neg_content for _ in range(len(lines))] - negative_prompt = "\n".join(negative_lines) - else: - negative_prompt = "" - - else: - combined_prompt = text_result - negative_prompt = "" - - try: - if isinstance(retrieved_image, torch.Tensor): - # Ensure it's in the correct format (B, C, H, W) - if retrieved_image.dim() == 3: # Single image (C, H, W) - image_tensor = retrieved_image.unsqueeze(0) # Add batch dimension - else: - image_tensor = retrieved_image # Already batched - - # Create matching batch masks - batch_size = image_tensor.shape[0] - height = image_tensor.shape[2] - width = image_tensor.shape[3] - - # Create white masks (all ones) for each image in batch - mask_tensor = torch.ones((batch_size, 1, height, width), - dtype=torch.float32, - device=image_tensor.device) - - if retrieved_mask is not None: - # If we have masks, process them to match the batch - if isinstance(retrieved_mask, torch.Tensor): - if retrieved_mask.dim() == 3: # Single mask - mask_tensor = retrieved_mask.unsqueeze(0) - else: - mask_tensor = retrieved_mask - else: - # Process retrieved_mask if it's not a tensor - mask_tensor = process_mask(retrieved_mask, image_tensor) - else: - image_tensor, default_mask_tensor = process_images_for_comfy( - retrieved_image, - self.placeholder_image_path - ) - mask_tensor = default_mask_tensor - - if retrieved_mask is not None: - mask_tensor = process_mask(retrieved_mask, image_tensor) - return ( - prompt, - combined_prompt, - negative_prompt, - omni, - image_tensor, - mask_tensor, - ) - - except Exception as e: - logger.error(f"Exception in image processing: {str(e)}", exc_info=True) - placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path) - return ( - prompt, - f"Error: {str(e)}", - "", - None, - placeholder_image, - placeholder_mask - ) - - except Exception as e: - logger.error(f"Exception occurred in process_chat: {str(e)}", exc_info=True) - placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path) - return ( - prompt, - f"Error: {str(e)}", - "", - None, - placeholder_image, - placeholder_mask - ) - - async def generate_response( - self, - enable_RAG, - query_type, - prompt, - preset, - llm_provider, - base_ip, - port, - images, - llm_model, - system_message_str, - messages, - temperature, - max_tokens, - random, - top_k, - top_p, - repeat_penalty, - stop, - seed, - keep_alive, - llm_api_key, - task=None, - fill_mask=False, - output_mask_select="", - precision="fp16", - attention="sdpa", - ): - response_strategies = { - "graphrag": self.graphrag_app.query, - "colpali": self.colpali_app.query, - "florence": self.florence_app.run_florence, - "normal": lambda: send_request( - llm_provider=llm_provider, - base_ip=base_ip, - port=port, - images=images, - llm_model=llm_model, - system_message=system_message_str, - user_message=prompt, - messages=messages, - seed=seed, - temperature=temperature, - max_tokens=max_tokens, - random=random, - top_k=top_k, - top_p=top_p, - repeat_penalty=repeat_penalty, - stop=stop, - keep_alive=keep_alive, - llm_api_key=llm_api_key, - tools=None, - tool_choice=None, - precision=precision, - attention=attention - ), - } - florence_tasks = list(self.florence_prompts.keys()) - if enable_RAG: - if query_type == "colpali" or query_type == "colpali-v1.2" or query_type == "colqwen2": - strategy = "colpali" - else: # For "global", "local", and "naive" query types - strategy = "graphrag" - elif task and task.lower() != 'none' and task in florence_tasks: - strategy = "florence" - else: - strategy = "normal" - - print(f"Strategy: {strategy}") - - try: - if strategy == "colpali": - # Ensure the model is loaded before querying - if self.cached_colpali_model is None: - self.cached_colpali_model = self.colpali_app.get_colpali_model(query_type) - response = await response_strategies[strategy](prompt=prompt, query_type=query_type, system_message_str=system_message_str) - return response - elif strategy == "graphrag": - response = await response_strategies[strategy](prompt=prompt, query_type=query_type, preset=preset) - return { - "Question": prompt, - "Response": response[0], - "Negative": "", - "Tool_Output": response[1], - "Retrieved_Image": None, - "Mask": None - } - elif strategy == "florence": - task_content = self.florence_prompts.get(task, "") - response = await response_strategies[strategy]( - images=images, - task=task, - task_prompt=task_content, - llm_model=llm_model, - precision=precision, - attention=attention, - fill_mask=fill_mask, - output_mask_select=output_mask_select, - keep_alive=keep_alive, - max_new_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - repetition_penalty=repeat_penalty, - seed=seed, - text_input=prompt, - ) - print("Florence response:", response) - return response - else: - response = await response_strategies[strategy]() - print("Normal response:", response) - return { - "Question": prompt, - "Response": response, - "Negative": "", - "Tool_Output": None, - "Retrieved_Image": None, - "Mask": None - } - - except Exception as e: - logger.error(f"Error processing strategy: {strategy}") - return { - "Question": prompt, - "Response": f"Error processing task: {str(e)}", - "Negative": "", - "Tool_Output": {"error": str(e)}, - "Retrieved_Image": None, - "Mask": None - } - - def process_chat_wrapper(self, *args, **kwargs): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - logger.debug(f"process_chat_wrapper kwargs: {kwargs}") - logger.debug(f"External LLM API Key: {kwargs.get('external_api_key', 'Not provided')}") - return loop.run_until_complete(self.process_chat(*args, **kwargs)) - - @classmethod - def INPUT_TYPES(cls): - node = cls() - return { - "required": { - "prompt": ("STRING", {"multiline": True, "default": "", "tooltip": "The main text input for the chat or query."}), - "llm_provider": (["xai","llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "transformers"], {"default": node.llm_provider, "tooltip": "The provider of the language model to be used."}), - "llm_model": ((), {"tooltip": "The specific language model to be used for processing."}), - "base_ip": ("STRING", {"default": node.base_ip, "tooltip": "IP address of the LLM server."}), - "port": ("STRING", {"default": node.port, "tooltip": "Port number for the LLM server connection."}), - }, - "optional": { - "images": ("IMAGE", {"list": True, "tooltip": "Input image(s) for visual processing or context."}), - "precision": (['fp16','bf16','fp32','int8','int4'],{"default": 'bf16', "tooltip": "Select preccision on Transformer models."}), - "attention": (['flash_attention_2','sdpa','xformers', 'Shrek_COT_o1'],{"default": 'sdpa', "tooltip": "Select attention mechanism on Transformer models."}), - "assistant": ([name for name in node.assistants.keys()], {"default": node.assistant, "tooltip": "The pre-defined assistant personality to use for responses."}), - "tool": (["None"] + [name for name in node.agent_tools.keys()], {"default": "None", "tooltip": "Selects a specific tool or agent for task execution."}), - "temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1, "tooltip": "Controls randomness in output generation. Higher values increase creativity but may reduce coherence."}), - "max_tokens": ("INT", {"default": 2048, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Maximum number of tokens to generate in the response."}), - "top_k": ("INT", {"default": 40, "min": 0, "max": 100, "tooltip": "Limits the next token selection to the K most likely tokens."}), - "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.1, "tooltip": "Cumulative probability cutoff for token selection."}), - "repeat_penalty": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Penalizes repetition in generated text."}), - "stop_string": ([name for name in node.stop_strings.keys()], {"tooltip": "Specifies a string at which text generation should stop."}), - "seed": ("INT", {"default": 94687328150, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Random seed for reproducible outputs."}), - "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature", "tooltip": "Toggles between using a fixed seed or temperature-based randomness."}), - "history_steps": ("INT", {"default": 10, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Number of previous conversation turns to consider for context."}), - "clear_history": ("BOOLEAN", {"default": False, "label_on": "Clear History", "label_off": "Keep History", "tooltip": "Option to clear or retain conversation history."}), - "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps Model on Memory", "label_off": "Unloads Model from Memory", "tooltip": "Determines whether to keep the model loaded in memory between calls."}), - "text_cleanup": ("BOOLEAN", {"default": True, "label_on": "Clean Response", "label_off": "Raw Text", "tooltip": "Applies text cleaning to the generated output."}), - "mode": ("BOOLEAN", {"default": False, "label_on": "Using SD Mode", "label_off": "Using Chat Mode", "tooltip": "Switches between Stable Diffusion prompt generation and standard chat mode."}), - "embellish_prompt": ([name for name in node.embellish_prompts.keys()], {"tooltip": "Adds pre-defined embellishments to the prompt."}), - "style_prompt": ([name for name in node.style_prompts.keys()], {"tooltip": "Applies a pre-defined style to the prompt."}), - "neg_prompt": ([name for name in node.neg_prompts.keys()], {"tooltip": "Adds a negative prompt to guide what should be avoided in generation."}), - "fill_mask": ("BOOLEAN", {"default": False, "label_on": "Fill Mask", "label_off": "No Fill", "tooltip": "Option to fill masks for Florence tasks."}), - "output_mask_select": ("STRING", {"default": ""}), - "task": ([name for name in node.florence_prompts.keys()], {"default": "None", "tooltip": "Select a Florence task."}), - "embedding_provider": (["llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "sentence_transformers"], {"default": node.embedding_provider, "tooltip": "Provider for text embedding model."}), - "embedding_model": ((), {"tooltip": "Specific embedding model to use."}), - "tool_input": ("OMNI", {"default": None, "tooltip": "Additional input for the selected tool."}), - "prime_directives": ("STRING", {"forceInput": True, "tooltip": "System message or prime directive for the AI assistant."}), - "external_api_key":("STRING", {"default": "", "tooltip": "If this is not empty, it will be used instead of the API key from the .env file. Make sure it is empty to use the .env file."}), - "top_k_search": ("INT", {"default": 3, "min": 1, "max": 10, "tooltip": "Find top scored image(s) from RAG."}), - "aspect_ratio": (["1:1", "9:16", "16:9"], {"default": "16:9", "tooltip": "Select the aspect ratio for the image."}), - "enable_RAG": ("BOOLEAN", {"default": False, "label_on": "RAG is Enabled", "label_off": "RAG is Disabled", "tooltip": "Enables Retrieval-Augmented Generation for enhanced context."}), - "query_type": (["global", "local", "naive", "colpali", "colqwen2", "colpali-v1.2"], {"default": "global", "tooltip": "Selects the type of query strategy for RAG."}), - "preset": (["Default", "Detailed", "Quick", "Bullet", "Comprehensive", "High-Level", "Focused"], {"default": "Default"}), - }, - "hidden": { - "model": ("STRING", {"default": ""}), - "rag_root_dir": ("STRING", {"default": "rag_data"}) - } - } - - @classmethod - def IS_CHANGED(cls, **kwargs): - node = cls() - - llm_provider = kwargs.get('llm_provider', node.llm_provider) - embedding_provider = kwargs.get('embedding_provider', node.embedding_provider) - base_ip = kwargs.get('base_ip', node.base_ip) - port = kwargs.get('port', node.port) - query_type = kwargs.get('query_type', node.query_type) - external_api_key = kwargs.get('external_api_key', '') - task = kwargs.get('task', node.task) - - # Determine which API key to use - def get_api_key_with_fallback(provider, external_api_key): - if external_api_key and external_api_key != '': - return external_api_key - try: - # print(f"Using {provider} API key from .env file") - api_key = get_api_key(f"{provider.upper()}_API_KEY", provider) - # print(f" {api_key} API key for {provider} found in .env file") - return api_key - - except ValueError: - return None - - api_key = get_api_key_with_fallback(llm_provider, external_api_key) - - # Check for changes - llm_provider_changed = llm_provider != node.llm_provider - embedding_provider_changed = embedding_provider != node.embedding_provider - api_key_changed = external_api_key != node.external_api_key - base_ip_changed = base_ip != node.base_ip - port_changed = port != node.port - query_type_changed = query_type != node.query_type - task_changed = task != node.task - - # Always fetch new models if the provider, API key, base_ip, or port has changed - if llm_provider_changed or api_key_changed or base_ip_changed or port_changed: - try: - new_llm_models = get_models(llm_provider, base_ip, port, api_key) - except Exception as e: - print(f"Error fetching LLM models: {e}") - new_llm_models = [] - llm_model_changed = new_llm_models != node.llm_model - else: - llm_model_changed = False - - if embedding_provider_changed or api_key_changed or base_ip_changed or port_changed: - try: - new_embedding_models = get_models(embedding_provider, base_ip, port, api_key) - except Exception as e: - print(f"Error fetching embedding models: {e}") - new_embedding_models = [] - embedding_model_changed = new_embedding_models != node.embedding_model - else: - embedding_model_changed = False - - if (llm_provider_changed or embedding_provider_changed or llm_model_changed or - embedding_model_changed or query_type_changed or task_changed or api_key_changed or - base_ip_changed or port_changed): - - node.llm_provider = llm_provider - node.embedding_provider = embedding_provider - node.base_ip = base_ip - node.port = port - node.external_api_key = external_api_key - node.query_type = query_type - node.task = task - - if llm_model_changed: - node.llm_model = new_llm_models - if embedding_model_changed: - node.embedding_model = new_embedding_models - - # Update other attributes - for attr in ['seed', 'random', 'history_steps', 'clear_history', 'mode', - 'keep_alive', 'tool', 'enable_RAG', 'preset']: - setattr(node, attr, kwargs.get(attr, getattr(node, attr))) - - return True - - return False - - RETURN_TYPES = ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK") - RETURN_NAMES = ("Question", "Response", "Negative", "Tool_Output", "Retrieved_Image", "Mask") - - OUTPUT_TOOLTIPS = ( - "The original input question or prompt.", - "The generated response from the language model.", - "The negative prompt used (if applicable) for guiding image generation.", - "Output from the selected tool, which can be code or any other data type.", - "An image retrieved by the RAG system, if applicable.", - "Mask image generated by Florence tasks." - ) - FUNCTION = "process_chat_wrapper" - OUTPUT_NODE = True - CATEGORY = "ImpactFrames💥🎞️" - DESCRIPTION = "ComfyUI, Support API and Local LLM providers and RAG capabilities. Processes text prompts, handles image inputs, and integrates with different language models and indexing strategies." - +# IFChatPromptNode.py +import os +import sys +import json +import torch +import shutil +import base64 +import platform +import importlib +import subprocess +import numpy as np +import folder_paths +from PIL import Image +import yaml +from io import BytesIO +import asyncio +from typing import List, Union, Dict, Any, Tuple, Optional +from .agent_tool import AgentTool +from .send_request import send_request +#from .transformers_api import TransformersModelManager +import tempfile +import threading +from aiohttp import web +from .graphRAG_module import GraphRAGapp +from .colpaliRAG_module import colpaliRAGapp +from .superflorence import FlorenceModule +from .utils import get_api_key, get_models, validate_models, clean_text, process_mask, load_placeholder_image, process_images_for_comfy +#from byaldi import RAGMultiModalModel +# Set up logging +import logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) +# Add the ComfyUI directory to the Python path +comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.insert(0, comfy_path) + +ifchat_prompt_node = None + +try: + from server import PromptServer + + @PromptServer.instance.routes.post("/IF_ChatPrompt/get_llm_models") + async def get_llm_models_endpoint(request): + data = await request.json() + llm_provider = data.get("llm_provider") + engine = llm_provider + base_ip = data.get("base_ip") + port = data.get("port") + external_api_key = data.get("external_api_key") + + logger.debug(f"Received request for LLM models. Provider: {llm_provider}, External API key provided: {bool(external_api_key)}") + + if external_api_key: + api_key = external_api_key + logger.debug("Using provided external LLM API key") + else: + api_key_name = f"{llm_provider.upper()}_API_KEY" + try: + api_key = get_api_key(api_key_name, engine) + logger.debug("Using API key from environment or .env file") + except ValueError: + logger.warning(f"No API key found for {llm_provider}. Attempting to proceed without an API key.") + api_key = None + + models = get_models(engine, base_ip, port, api_key) + logger.debug(f"Fetched {len(models)} models for {llm_provider}") + return web.json_response(models) + + @PromptServer.instance.routes.post("/IF_ChatPrompt/get_embedding_models") + async def get_embedding_models_endpoint(request): + data = await request.json() + embedding_provider = data.get("embedding_provider") + engine = embedding_provider + base_ip = data.get("base_ip") + port = data.get("port") + external_api_key = data.get("external_api_key") + + logger.debug(f"Received request for LLM models. Provider: {embedding_provider}, External API key provided: {bool(external_api_key)}") + + if external_api_key: + api_key = external_api_key + logger.debug("Using provided external LLM API key") + else: + api_key_name = f"{embedding_provider.upper()}_API_KEY" + try: + api_key = get_api_key(api_key_name, engine) + logger.debug("Using API key from environment or .env file") + except ValueError: + logger.warning(f"No API key found for {embedding_provider}. Attempting to proceed without an API key.") + api_key = None + + models = get_models(engine, base_ip, port, api_key) + logger.debug(f"Fetched {len(models)} models for {embedding_provider}") + return web.json_response(models) + + @PromptServer.instance.routes.post("/IF_ChatPrompt/upload_file") + async def upload_file_route(request): + try: + reader = await request.multipart() + + rag_folder_name = None + file_content = None + filename = None + + # Process all parts of the multipart request + while True: + part = await reader.next() + if part is None: + break + if part.name == "rag_root_dir": + rag_folder_name = await part.text() + elif part.filename: + filename = part.filename + file_content = await part.read() + + if not filename or not file_content or not rag_folder_name: + return web.json_response({"status": "error", "message": "Missing file, filename, or RAG folder name"}) + + node = IFChatPrompt() + input_dir = os.path.join(node.rag_dir, rag_folder_name, "input") + + if not os.path.exists(input_dir): + os.makedirs(input_dir, exist_ok=True) + + file_path = os.path.join(input_dir, filename) + + with open(file_path, 'wb') as f: + f.write(file_content) + + logger.info(f"File uploaded to: {file_path}") + return web.json_response({"status": "success", "message": f"File uploaded to: {file_path}"}) + + except Exception as e: + logger.error(f"Error in upload_file_route: {str(e)}") + return web.json_response({"status": "error", "message": f"Error uploading file: {str(e)}"}) + + @PromptServer.instance.routes.post("/IF_ChatPrompt/setup_and_initialize") + async def setup_and_initialize(request): + global ifchat_prompt_node + + data = await request.json() + folder_name = data.get('folder_name', 'rag_data') + + if ifchat_prompt_node is None: + ifchat_prompt_node = IFChatPrompt() + + init_result = await ifchat_prompt_node.graphrag_app.setup_and_initialize_folder(folder_name, data) + + ifchat_prompt_node.rag_folder_name = folder_name + ifchat_prompt_node.colpali_app.set_rag_root_dir(folder_name) + + return web.json_response(init_result) + + @PromptServer.instance.routes.post("/IF_ChatPrompt/run_indexer") + async def run_indexer_endpoint(request): + try: + data = await request.json() + logger.debug(f"Received indexing request with data: {data}") + + global ifchat_prompt_node # Access the global instance + + # Set the rag_root_dir in both modules using the global instance + ifchat_prompt_node.graphrag_app.set_rag_root_dir(data.get('rag_folder_name')) + ifchat_prompt_node.colpali_app.set_rag_root_dir(data.get('rag_folder_name')) + + query_type = data.get('mode_type') + logger.debug(f"Query type: {query_type}") + + logger.debug(f"Starting indexing process for query type: {query_type}") + + # Initialize the colpali_model before calling insert, using the global instance + if query_type == 'colpali' or query_type == 'colqwen2' or query_type == 'colpali-v1.2': + _ = ifchat_prompt_node.colpali_app.get_colpali_model(query_type) # This will load or retrieve the cached model + result = await ifchat_prompt_node.colpali_app.insert() + else: + result = await ifchat_prompt_node.graphrag_app.insert() + + logger.debug(f"Indexing process completed with result: {result}") + + if result: + return web.json_response({"status": "success", "message": f"Indexing complete for {query_type}"}) + else: + return web.json_response({"status": "error", "message": "Indexing failed. Check server logs."}, status=500) + + except Exception as e: + logger.error(f"Error in run_indexer_endpoint: {str(e)}") + return web.json_response({"status": "error", "message": f"Error during indexing: {str(e)}"}, status=500) + + @PromptServer.instance.routes.post("/IF_ChatPrompt/process_chat") + async def process_chat_endpoint(request): + try: + data = await request.json() + + # Set default values for required arguments if not provided + defaults = { + "prompt": "", + "assistant": "Cortana", # Default assistant + "neg_prompt": "Default", # Default negative prompt + "embellish_prompt": "Default", # Default embellishment + "style_prompt": "Default", # Default style + "llm_provider": "ollama", + "llm_model": "", + "base_ip": "localhost", + "port": "11434", + "embedding_model": "", + "embedding_provider": "sentence_transformers" + } + + # Update data with defaults for missing keys + for key, default_value in defaults.items(): + if key not in data: + data[key] = default_value + + global ifchat_prompt_node + result = await ifchat_prompt_node.process_chat(**data) + + return web.json_response(result) + + except Exception as e: + logger.error(f"Error in process_chat_endpoint: {str(e)}") + return web.json_response({ + "status": "error", + "message": f"Error processing chat: {str(e)}", + "Question": data.get("prompt", ""), + "Response": f"Error: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": None, + "Mask": None + }, status=500) + + @PromptServer.instance.routes.post("/IF_ChatPrompt/load_index") + async def load_index_route(request): + try: + data = await request.json() + index_name = data.get('rag_folder_name') + query_type = data.get('query_type') + + if not index_name: + logger.error("No index name provided in the request.") + return web.json_response({ + "status": "error", + "message": "No index name provided" + }) + + # Check if index exists in .byaldi directory + byaldi_index_path = os.path.join(".byaldi", index_name) + if not os.path.exists(byaldi_index_path): + logger.error(f"Index not found in .byaldi: {byaldi_index_path}") + return web.json_response({ + "status": "error", + "message": f"Index {index_name} does not exist" + }) + + try: + global ifchat_prompt_node + if ifchat_prompt_node is None: + logger.debug("Initializing IFChatPrompt instance.") + ifchat_prompt_node = IFChatPrompt() + + if query_type in ['colpali', 'colqwen2', 'colpali-v1.2']: + logger.debug(f"Loading model for query type: {query_type}") + + # Clear any existing cached index + ifchat_prompt_node.colpali_app.cleanup_index() + + # First get the base model + colpali_model = ifchat_prompt_node.colpali_app.get_colpali_model(query_type) + + if colpali_model: + # Load and cache the new index + model = await ifchat_prompt_node.colpali_app._prepare_model(query_type, index_name) + if not model: + raise ValueError("Failed to load and cache index") + + # Set the RAG root directory + ifchat_prompt_node.colpali_app.set_rag_root_dir(index_name) + + logger.info(f"Successfully loaded and cached index: {index_name}") + return web.json_response({ + "status": "success", + "message": f"Successfully loaded index: {index_name}", + "rag_root_dir": index_name + }) + else: + logger.error("Failed to initialize ColPali model.") + raise ValueError("Failed to initialize ColPali model") + + else: + logger.error(f"Unsupported query type: {query_type}") + return web.json_response({ + "status": "error", + "message": f"Query type {query_type} not supported for loading indexes" + }) + + except Exception as e: + logger.error(f"Error loading index {index_name}: {str(e)}") + return web.json_response({ + "status": "error", + "message": f"Error loading index: {str(e)}" + }) + + except Exception as e: + logger.error(f"Error in load_index_route: {str(e)}") + return web.json_response({ + "status": "error", + "message": f"Error processing request: {str(e)}" + }) + + # Add this with the other routes + @PromptServer.instance.routes.post("/IF_ChatPrompt/delete_index") + async def delete_index_route(request): + try: + data = await request.json() + index_name = data.get('rag_folder_name') + + if not index_name: + return web.json_response({ + "status": "error", + "message": "No index name provided" + }) + + # Path to the index + index_path = os.path.join(".byaldi", index_name) + + if not os.path.exists(index_path): + return web.json_response({ + "status": "error", + "message": f"Index {index_name} does not exist" + }) + + # Delete the index directory + try: + shutil.rmtree(index_path) + logger.info(f"Successfully deleted index: {index_name}") + return web.json_response({ + "status": "success", + "message": f"Successfully deleted index: {index_name}" + }) + except Exception as e: + logger.error(f"Error deleting index {index_name}: {str(e)}") + return web.json_response({ + "status": "error", + "message": f"Error deleting index: {str(e)}" + }) + + except Exception as e: + logger.error(f"Error in delete_index_route: {str(e)}") + return web.json_response({ + "status": "error", + "message": f"Error processing request: {str(e)}" + }) + +except AttributeError: + print("PromptServer.instance not available. Skipping route decoration for IF_ChatPrompt.") + +class IFChatPrompt: + + def __init__(self): + self.base_ip = "localhost" + self.port = "11434" + self.llm_provider = "ollama" + self.embedding_provider = "sentence_transformers" + self.llm_model = "" + self.embedding_model = "" + self.assistant = "None" + self.random = False + + self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "rag") + self.presets_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "presets") + + self.stop_file = os.path.join(self.presets_dir, "stop_strings.json") + self.assistants_file = os.path.join(self.presets_dir, "assistants.json") + self.neg_prompts_file = os.path.join(self.presets_dir, "neg_prompts.json") + self.embellish_prompts_file = os.path.join(self.presets_dir, "embellishments.json") + self.style_prompts_file = os.path.join(self.presets_dir, "style_prompts.json") + self.tasks_file = os.path.join(self.presets_dir, "florence_prompts.json") + self.agents_dir = os.path.join(self.presets_dir, "agents") + + self.agent_tools = self.load_agent_tools() + self.stop_strings = self.load_presets(self.stop_file) + self.assistants = self.load_presets(self.assistants_file) + self.neg_prompts = self.load_presets(self.neg_prompts_file) + self.embellish_prompts = self.load_presets(self.embellish_prompts_file) + self.style_prompts = self.load_presets(self.style_prompts_file) + self.florence_prompts = self.load_presets(self.tasks_file) + + self.keep_alive = False + self.seed = 94687328150 + self.messages = [] + self.history_steps = 10 + self.external_api_key = "" + self.tool_input = "" + self.prime_directives = None + self.rag_folder_name = "rag_data" + self.graphrag_app = GraphRAGapp() + self.colpali_app = colpaliRAGapp() + self.fix_json = True + self.cached_colpali_model = None + self.florence_app = FlorenceModule() + self.florence_models = {} + self.query_type = "global" + self.enable_RAG = False + self.clear_history = False + self.mode = False + self.tool = "None" + self.preset = "Default" + self.precision = "fp16" + self.task = None + self.attention = "sdpa" + self.aspect_ratio = "16:9" + self.top_k_search = 3 + + self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "placeholder.png") + + if not os.path.exists(self.placeholder_image_path): + placeholder = Image.new('RGB', (512, 512), color=(73, 109, 137)) + os.makedirs(os.path.dirname(self.placeholder_image_path), exist_ok=True) + placeholder.save(self.placeholder_image_path) + + def load_presets(self, file_path): + with open(file_path, 'r') as f: + presets = json.load(f) + return presets + + def load_agent_tools(self): + os.makedirs(self.agents_dir, exist_ok=True) + agent_tools = {} + try: + for filename in os.listdir(self.agents_dir): + if filename.endswith('.json'): + full_path = os.path.join(self.agents_dir, filename) + with open(full_path, 'r') as f: + try: + data = json.load(f) + if 'output_type' not in data: + data['output_type'] = None + agent_tool = AgentTool(**data) + agent_tool.load() + if agent_tool._class_instance is not None: + if agent_tool.python_function: + agent_tools[agent_tool.name] = agent_tool + else: + print(f"Warning: Agent tool {agent_tool.name} in {filename} does not have a python_function defined.") + else: + print(f"Failed to create class instance for {filename}") + except json.JSONDecodeError: + print(f"Error: Invalid JSON in {filename}") + except Exception as e: + print(f"Error loading {filename}: {str(e)}") + return agent_tools + except Exception as e: + print(f"Warning: Error accessing agent tools directory: {str(e)}") + return {} + + async def process_chat( + self, + prompt, + llm_provider, + llm_model, + base_ip, + port, + assistant, + neg_prompt, + embellish_prompt, + style_prompt, + embedding_model, + embedding_provider, + external_api_key="", + temperature=0.7, + max_tokens=2048, + seed=0, + random=False, + history_steps=10, + keep_alive=False, + top_k=40, + top_p=0.2, + repeat_penalty=1.1, + stop_string=None, + images=None, + mode=True, + clear_history=False, + text_cleanup=True, + tool=None, + tool_input=None, + prime_directives=None, + enable_RAG=False, + query_type="global", + preset="Default", + rag_folder_name=None, + task=None, + fill_mask=False, + output_mask_select="", + precision="fp16", + attention="sdpa", + aspect_ratio="16:9", + top_k_search=3 + ): + + if external_api_key != "": + llm_api_key = external_api_key + else: + llm_api_key = get_api_key(f"{llm_provider.upper()}_API_KEY", llm_provider) + + print(f"LLM API key: {llm_api_key[:5]}...") + if prime_directives is not None: + system_message_str = prime_directives + else: + system_message = self.assistants.get(assistant, "") + system_message_str = json.dumps(system_message) + + # Validate LLM model + validate_models(llm_model, llm_provider, "LLM", base_ip, port, llm_api_key) + + # Validate embedding model + validate_models(embedding_model, embedding_provider, "embedding", base_ip, port, llm_api_key) + + # Handle history + if clear_history: + self.messages = [] + elif history_steps > 0: + self.messages = self.messages[-history_steps:] + + messages = self.messages + + # Handle stop + if stop_string is None or stop_string == "None": + stop_content = None + else: + stop_content = self.stop_strings.get(stop_string, None) + stop = stop_content + + if llm_provider not in ["ollama", "llamacpp", "vllm", "lmstudio", "gemeni"]: + if llm_provider == "kobold": + stop = stop_content + \ + ["\n\n\n\n\n"] if stop_content else ["\n\n\n\n\n"] + elif llm_provider == "mistral": + stop = stop_content + \ + ["\n\n"] if stop_content else ["\n\n"] + else: + stop = stop_content if stop_content else None + # Handle tools + try: + if tool and tool != "None": + selected_tool = self.agent_tools.get(tool) + if not selected_tool: + raise ValueError(f"Invalid agent tool selected: {tool}") + + # Prepare tool execution message + tool_message = f"Execute the {tool} tool with the following input: {prompt}" + system_prompt = json.dumps(selected_tool.system_prompt) + + # Send request to LLM for tool execution + generated_text =await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images, + model=llm_model, + system_message=system_prompt, + user_message=tool_message, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + ) + # Parse the generated text for function calls + function_call = None + try: + response_data = json.loads(generated_text) + if 'function_call' in response_data: + function_call = response_data['function_call'] + generated_text = response_data['content'] + except json.JSONDecodeError: + pass # The response wasn't JSON, so it's just the generated text + + # Execute the tool with the LLM's response + tool_args = { + "input": prompt, + "llm_response": generated_text, + "function_call": function_call, + "omni_input": tool_input, + "name": selected_tool.name, + "description": selected_tool.description, + "system_prompt": selected_tool.system_prompt + } + tool_result = selected_tool.execute(tool_args) + + # Update messages + messages.append({"role": "user", "content": prompt}) + messages.append({ + "role": "assistant", + "content": json.dumps(tool_result) if isinstance(tool_result, dict) else str(tool_result) + }) + + # Process the tool output + if isinstance(tool_result, dict): + if "error" in tool_result: + generated_text = f"Error in {tool}: {tool_result['error']}" + tool_output = None + elif selected_tool.output_type and selected_tool.output_type in tool_result: + tool_output = tool_result[selected_tool.output_type] + generated_text = f"Agent {tool} executed successfully. Output generated." + else: + tool_output = tool_result + generated_text = str(tool_output) + else: + tool_output = tool_result + generated_text = str(tool_output) + + return { + "Question": prompt, + "Response": generated_text, + "Negative": self.neg_prompts.get(neg_prompt, ""), + "Tool_Output": tool_output, + "Retrieved_Image": None # No image retrieved in tool execution + } + else: + response = await self.generate_response( + enable_RAG, + query_type, + prompt, + preset, + llm_provider, + base_ip, + port, + images, + llm_model, + system_message_str, + messages, + temperature, + max_tokens, + random, + top_k, + top_p, + repeat_penalty, + stop, + seed, + keep_alive, + llm_api_key, + task, + fill_mask, + output_mask_select, + precision, + attention + ) + + generated_text = response.get("Response") + selected_neg_prompt_name = neg_prompt + omni = response.get("Tool_Output") + retrieved_image = response.get("Retrieved_Image") + retrieved_mask = response.get("Mask") + + + # Update messages + messages.append({"role": "user", "content": prompt}) + messages.append({"role": "assistant", "content": generated_text}) + + text_result = str(generated_text).strip() + + if mode: + embellish_content = self.embellish_prompts.get(embellish_prompt, "").strip() + style_content = self.style_prompts.get(style_prompt, "").strip() + + lines = [line.strip() for line in text_result.split('\n') if line.strip()] + combined_prompts = [] + + for line in lines: + if text_cleanup: + line = clean_text(line) + formatted_line = f"{embellish_content} {line} {style_content}".strip() + combined_prompts.append(formatted_line) + + combined_prompt = "\n".join(formatted_line for formatted_line in combined_prompts) + # Handle negative prompts + if selected_neg_prompt_name == "AI_Fill": + try: + neg_system_message = self.assistants.get("NegativePromptEngineer") + if not neg_system_message: + logger.error("NegativePromptEngineer not found in assistants configuration") + negative_prompt = "Error: NegativePromptEngineer not configured" + else: + user_message = f"Generate negative prompts for the following prompt:\n{text_result}" + + system_message_str = json.dumps(neg_system_message) + + logger.info(f"Requesting negative prompts for prompt: {text_result[:100]}...") + + neg_response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=None, + llm_model=llm_model, + system_message=system_message_str, + user_message=user_message, + messages=[], # Fresh context for negative generation + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key + ) + + logger.debug(f"Received negative prompt response: {neg_response}") + + if neg_response: + negative_lines = [] + for line in neg_response.split('\n'): + line = line.strip() + if line: + negative_lines.append(line) + + while len(negative_lines) < len(lines): + negative_lines.append(negative_lines[-1] if negative_lines else "") + negative_lines = negative_lines[:len(lines)] + + negative_prompt = "\n".join(negative_lines) + else: + negative_prompt = "Error: Empty response from LLM" + except Exception as e: + logger.error(f"Error generating negative prompts: {str(e)}", exc_info=True) + negative_prompt = f"Error generating negative prompts: {str(e)}" + + elif neg_prompt != "None": + neg_content = self.neg_prompts.get(neg_prompt, "").strip() + negative_lines = [neg_content for _ in range(len(lines))] + negative_prompt = "\n".join(negative_lines) + else: + negative_prompt = "" + + else: + combined_prompt = text_result + negative_prompt = "" + + try: + if isinstance(retrieved_image, torch.Tensor): + # Ensure it's in the correct format (B, C, H, W) + if retrieved_image.dim() == 3: # Single image (C, H, W) + image_tensor = retrieved_image.unsqueeze(0) # Add batch dimension + else: + image_tensor = retrieved_image # Already batched + + # Create matching batch masks + batch_size = image_tensor.shape[0] + height = image_tensor.shape[2] + width = image_tensor.shape[3] + + # Create white masks (all ones) for each image in batch + mask_tensor = torch.ones((batch_size, 1, height, width), + dtype=torch.float32, + device=image_tensor.device) + + if retrieved_mask is not None: + # If we have masks, process them to match the batch + if isinstance(retrieved_mask, torch.Tensor): + if retrieved_mask.dim() == 3: # Single mask + mask_tensor = retrieved_mask.unsqueeze(0) + else: + mask_tensor = retrieved_mask + else: + # Process retrieved_mask if it's not a tensor + mask_tensor = process_mask(retrieved_mask, image_tensor) + else: + image_tensor, default_mask_tensor = process_images_for_comfy( + retrieved_image, + self.placeholder_image_path + ) + mask_tensor = default_mask_tensor + + if retrieved_mask is not None: + mask_tensor = process_mask(retrieved_mask, image_tensor) + return ( + prompt, + combined_prompt, + negative_prompt, + omni, + image_tensor, + mask_tensor, + ) + + except Exception as e: + logger.error(f"Exception in image processing: {str(e)}", exc_info=True) + placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path) + return ( + prompt, + f"Error: {str(e)}", + "", + None, + placeholder_image, + placeholder_mask + ) + + except Exception as e: + logger.error(f"Exception occurred in process_chat: {str(e)}", exc_info=True) + placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path) + return ( + prompt, + f"Error: {str(e)}", + "", + None, + placeholder_image, + placeholder_mask + ) + + async def generate_response( + self, + enable_RAG, + query_type, + prompt, + preset, + llm_provider, + base_ip, + port, + images, + llm_model, + system_message_str, + messages, + temperature, + max_tokens, + random, + top_k, + top_p, + repeat_penalty, + stop, + seed, + keep_alive, + llm_api_key, + task=None, + fill_mask=False, + output_mask_select="", + precision="fp16", + attention="sdpa", + ): + response_strategies = { + "graphrag": self.graphrag_app.query, + "colpali": self.colpali_app.query, + "florence": self.florence_app.run_florence, + "normal": lambda: send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images, + llm_model=llm_model, + system_message=system_message_str, + user_message=prompt, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + tools=None, + tool_choice=None, + precision=precision, + attention=attention + ), + } + florence_tasks = list(self.florence_prompts.keys()) + if enable_RAG: + if query_type == "colpali" or query_type == "colpali-v1.2" or query_type == "colqwen2": + strategy = "colpali" + else: # For "global", "local", and "naive" query types + strategy = "graphrag" + elif task and task.lower() != 'none' and task in florence_tasks: + strategy = "florence" + else: + strategy = "normal" + + print(f"Strategy: {strategy}") + + try: + if strategy == "colpali": + # Ensure the model is loaded before querying + if self.cached_colpali_model is None: + self.cached_colpali_model = self.colpali_app.get_colpali_model(query_type) + response = await response_strategies[strategy](prompt=prompt, query_type=query_type, system_message_str=system_message_str) + return response + elif strategy == "graphrag": + response = await response_strategies[strategy](prompt=prompt, query_type=query_type, preset=preset) + return { + "Question": prompt, + "Response": response[0], + "Negative": "", + "Tool_Output": response[1], + "Retrieved_Image": None, + "Mask": None + } + elif strategy == "florence": + task_content = self.florence_prompts.get(task, "") + response = await response_strategies[strategy]( + images=images, + task=task, + task_prompt=task_content, + llm_model=llm_model, + precision=precision, + attention=attention, + fill_mask=fill_mask, + output_mask_select=output_mask_select, + keep_alive=keep_alive, + max_new_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repeat_penalty, + seed=seed, + text_input=prompt, + ) + print("Florence response:", response) + return response + else: + response = await response_strategies[strategy]() + print("Normal response:", response) + return { + "Question": prompt, + "Response": response, + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": None, + "Mask": None + } + + except Exception as e: + logger.error(f"Error processing strategy: {strategy}") + return { + "Question": prompt, + "Response": f"Error processing task: {str(e)}", + "Negative": "", + "Tool_Output": {"error": str(e)}, + "Retrieved_Image": None, + "Mask": None + } + + def process_chat_wrapper(self, *args, **kwargs): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + logger.debug(f"process_chat_wrapper kwargs: {kwargs}") + logger.debug(f"External LLM API Key: {kwargs.get('external_api_key', 'Not provided')}") + return loop.run_until_complete(self.process_chat(*args, **kwargs)) + + @classmethod + def INPUT_TYPES(cls): + node = cls() + return { + "required": { + "prompt": ("STRING", {"multiline": True, "default": "", "tooltip": "The main text input for the chat or query."}), + "llm_provider": (["xai","llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "transformers"], {"default": node.llm_provider, "tooltip": "The provider of the language model to be used."}), + "llm_model": ((), {"tooltip": "The specific language model to be used for processing."}), + "base_ip": ("STRING", {"default": node.base_ip, "tooltip": "IP address of the LLM server."}), + "port": ("STRING", {"default": node.port, "tooltip": "Port number for the LLM server connection."}), + }, + "optional": { + "images": ("IMAGE", {"list": True, "tooltip": "Input image(s) for visual processing or context."}), + "precision": (['fp16','bf16','fp32','int8','int4'],{"default": 'bf16', "tooltip": "Select preccision on Transformer models."}), + "attention": (['flash_attention_2','sdpa','xformers', 'Shrek_COT_o1'],{"default": 'sdpa', "tooltip": "Select attention mechanism on Transformer models."}), + "assistant": ([name for name in node.assistants.keys()], {"default": node.assistant, "tooltip": "The pre-defined assistant personality to use for responses."}), + "tool": (["None"] + [name for name in node.agent_tools.keys()], {"default": "None", "tooltip": "Selects a specific tool or agent for task execution."}), + "temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1, "tooltip": "Controls randomness in output generation. Higher values increase creativity but may reduce coherence."}), + "max_tokens": ("INT", {"default": 2048, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Maximum number of tokens to generate in the response."}), + "top_k": ("INT", {"default": 40, "min": 0, "max": 100, "tooltip": "Limits the next token selection to the K most likely tokens."}), + "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.1, "tooltip": "Cumulative probability cutoff for token selection."}), + "repeat_penalty": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Penalizes repetition in generated text."}), + "stop_string": ([name for name in node.stop_strings.keys()], {"tooltip": "Specifies a string at which text generation should stop."}), + "seed": ("INT", {"default": 94687328150, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Random seed for reproducible outputs."}), + "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature", "tooltip": "Toggles between using a fixed seed or temperature-based randomness."}), + "history_steps": ("INT", {"default": 10, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Number of previous conversation turns to consider for context."}), + "clear_history": ("BOOLEAN", {"default": False, "label_on": "Clear History", "label_off": "Keep History", "tooltip": "Option to clear or retain conversation history."}), + "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps Model on Memory", "label_off": "Unloads Model from Memory", "tooltip": "Determines whether to keep the model loaded in memory between calls."}), + "text_cleanup": ("BOOLEAN", {"default": True, "label_on": "Clean Response", "label_off": "Raw Text", "tooltip": "Applies text cleaning to the generated output."}), + "mode": ("BOOLEAN", {"default": False, "label_on": "Using SD Mode", "label_off": "Using Chat Mode", "tooltip": "Switches between Stable Diffusion prompt generation and standard chat mode."}), + "embellish_prompt": ([name for name in node.embellish_prompts.keys()], {"tooltip": "Adds pre-defined embellishments to the prompt."}), + "style_prompt": ([name for name in node.style_prompts.keys()], {"tooltip": "Applies a pre-defined style to the prompt."}), + "neg_prompt": ([name for name in node.neg_prompts.keys()], {"tooltip": "Adds a negative prompt to guide what should be avoided in generation."}), + "fill_mask": ("BOOLEAN", {"default": False, "label_on": "Fill Mask", "label_off": "No Fill", "tooltip": "Option to fill masks for Florence tasks."}), + "output_mask_select": ("STRING", {"default": ""}), + "task": ([name for name in node.florence_prompts.keys()], {"default": "None", "tooltip": "Select a Florence task."}), + "embedding_provider": (["llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "sentence_transformers"], {"default": node.embedding_provider, "tooltip": "Provider for text embedding model."}), + "embedding_model": ((), {"tooltip": "Specific embedding model to use."}), + "tool_input": ("OMNI", {"default": None, "tooltip": "Additional input for the selected tool."}), + "prime_directives": ("STRING", {"forceInput": True, "tooltip": "System message or prime directive for the AI assistant."}), + "external_api_key":("STRING", {"default": "", "tooltip": "If this is not empty, it will be used instead of the API key from the .env file. Make sure it is empty to use the .env file."}), + "top_k_search": ("INT", {"default": 3, "min": 1, "max": 10, "tooltip": "Find top scored image(s) from RAG."}), + "aspect_ratio": (["1:1", "9:16", "16:9"], {"default": "16:9", "tooltip": "Select the aspect ratio for the image."}), + "enable_RAG": ("BOOLEAN", {"default": False, "label_on": "RAG is Enabled", "label_off": "RAG is Disabled", "tooltip": "Enables Retrieval-Augmented Generation for enhanced context."}), + "query_type": (["global", "local", "naive", "colpali", "colqwen2", "colpali-v1.2"], {"default": "global", "tooltip": "Selects the type of query strategy for RAG."}), + "preset": (["Default", "Detailed", "Quick", "Bullet", "Comprehensive", "High-Level", "Focused"], {"default": "Default"}), + }, + "hidden": { + "model": ("STRING", {"default": ""}), + "rag_root_dir": ("STRING", {"default": "rag_data"}) + } + } + + @classmethod + def IS_CHANGED(cls, **kwargs): + node = cls() + + llm_provider = kwargs.get('llm_provider', node.llm_provider) + embedding_provider = kwargs.get('embedding_provider', node.embedding_provider) + base_ip = kwargs.get('base_ip', node.base_ip) + port = kwargs.get('port', node.port) + query_type = kwargs.get('query_type', node.query_type) + external_api_key = kwargs.get('external_api_key', '') + task = kwargs.get('task', node.task) + + # Determine which API key to use + def get_api_key_with_fallback(provider, external_api_key): + if external_api_key and external_api_key != '': + return external_api_key + try: + # print(f"Using {provider} API key from .env file") + api_key = get_api_key(f"{provider.upper()}_API_KEY", provider) + # print(f" {api_key} API key for {provider} found in .env file") + return api_key + + except ValueError: + return None + + api_key = get_api_key_with_fallback(llm_provider, external_api_key) + + # Check for changes + llm_provider_changed = llm_provider != node.llm_provider + embedding_provider_changed = embedding_provider != node.embedding_provider + api_key_changed = external_api_key != node.external_api_key + base_ip_changed = base_ip != node.base_ip + port_changed = port != node.port + query_type_changed = query_type != node.query_type + task_changed = task != node.task + + # Always fetch new models if the provider, API key, base_ip, or port has changed + if llm_provider_changed or api_key_changed or base_ip_changed or port_changed: + try: + new_llm_models = get_models(llm_provider, base_ip, port, api_key) + except Exception as e: + print(f"Error fetching LLM models: {e}") + new_llm_models = [] + llm_model_changed = new_llm_models != node.llm_model + else: + llm_model_changed = False + + if embedding_provider_changed or api_key_changed or base_ip_changed or port_changed: + try: + new_embedding_models = get_models(embedding_provider, base_ip, port, api_key) + except Exception as e: + print(f"Error fetching embedding models: {e}") + new_embedding_models = [] + embedding_model_changed = new_embedding_models != node.embedding_model + else: + embedding_model_changed = False + + if (llm_provider_changed or embedding_provider_changed or llm_model_changed or + embedding_model_changed or query_type_changed or task_changed or api_key_changed or + base_ip_changed or port_changed): + + node.llm_provider = llm_provider + node.embedding_provider = embedding_provider + node.base_ip = base_ip + node.port = port + node.external_api_key = external_api_key + node.query_type = query_type + node.task = task + + if llm_model_changed: + node.llm_model = new_llm_models + if embedding_model_changed: + node.embedding_model = new_embedding_models + + # Update other attributes + for attr in ['seed', 'random', 'history_steps', 'clear_history', 'mode', + 'keep_alive', 'tool', 'enable_RAG', 'preset']: + setattr(node, attr, kwargs.get(attr, getattr(node, attr))) + + return True + + return False + + RETURN_TYPES = ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK") + RETURN_NAMES = ("Question", "Response", "Negative", "Tool_Output", "Retrieved_Image", "Mask") + + OUTPUT_TOOLTIPS = ( + "The original input question or prompt.", + "The generated response from the language model.", + "The negative prompt used (if applicable) for guiding image generation.", + "Output from the selected tool, which can be code or any other data type.", + "An image retrieved by the RAG system, if applicable.", + "Mask image generated by Florence tasks." + ) + FUNCTION = "process_chat_wrapper" + OUTPUT_NODE = True + CATEGORY = "ImpactFrames💥🎞️" + DESCRIPTION = "ComfyUI, Support API and Local LLM providers and RAG capabilities. Processes text prompts, handles image inputs, and integrates with different language models and indexing strategies." + diff --git a/IFDisplayTextNode.py b/IFDisplayTextNode.py index f874754..b52f235 100644 --- a/IFDisplayTextNode.py +++ b/IFDisplayTextNode.py @@ -1,4 +1,9 @@ import sys +import logging +from typing import Optional + +# Initialize logger +logger = logging.getLogger(__name__) class IFDisplayText: def __init__(self): @@ -27,7 +32,11 @@ def INPUT_TYPES(cls): OUTPUT_NODE = True CATEGORY = "ImpactFrames💥🎞️" - def display_text(self, text, select): + def display_text(self, text: Optional[str], select): + if text is None: + logger.error("Received None for text input in display_text.") + return "" # Or handle appropriately + print("==================") print("IF_AI_tool_output:") print("==================") diff --git a/IFDisplayTextWildcardNode.py b/IFDisplayTextWildcardNode.py new file mode 100644 index 0000000..5b47047 --- /dev/null +++ b/IFDisplayTextWildcardNode.py @@ -0,0 +1,413 @@ +#IFDisplayTextWildcardNode.py +import os +import sys +import yaml +import json +import random +import re +import itertools +import threading +import traceback +from pathlib import Path +import folder_paths +from execution import ExecutionBlocker + +class IFDisplayTextWildcard: + def __init__(self): + self.wildcards = {} + self._execution_count = None + self.wildcard_lock = threading.Lock() + + # Initialize paths + self.base_path = folder_paths.base_path + self.presets_dir = os.path.join(self.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "presets") + self.wildcards_dir = os.path.join(self.presets_dir, "wildcards") # Updated path + + # Load wildcards + self.wildcards = self.load_wildcards() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": ("STRING", {"forceInput": True}), + "select": ("INT", { + "default": 0, + "min": 0, + "max": sys.maxsize, + "step": 1, + }), + "counter": ("INT", { + "default": -1, + "min": -1, + "max": 999999, + "step": 1, + "display": "number", + }), + }, + "optional": { + "dynamic_prompt": ("STRING", { + "multiline": True, + "defaultInput": True, + "placeholder": "Enter dynamic variables e.g. prefix={val1|val2}" + }), + "max_variants": ("INT", { + "default": 10, + "min": 1, + "max": 1000, + "step": 1, + }), + "wildcard_mode": ("BOOLEAN", { + "default": False, + "display": "button" + }), + } + } + + RETURN_TYPES = ("STRING", "STRING", "INT", "STRING") + RETURN_NAMES = ("text", "text_list", "count", "selected") + OUTPUT_IS_LIST = (False, True, False, False) + FUNCTION = "display_text" + OUTPUT_NODE = True + CATEGORY = "ImpactFrames💥🎞️" + + def load_wildcards(self): + """Load wildcards from YAML/JSON files in the specified directory""" + wildcard_dict = {} + wildcards_path = self.wildcards_dir + + def wildcard_normalize(x): + return x.replace("\\", "/").replace(' ', '-').lower() + + def read_wildcard_file(file_path): + """Read wildcard definitions from a file""" + _, ext = os.path.splitext(file_path) + key = wildcard_normalize(os.path.basename(file_path).split('.')[0]) + try: + if ext.lower() in ['.yaml', '.yml']: + with open(file_path, 'r', encoding="utf-8") as f: + yaml_data = yaml.safe_load(f) + # Flatten the nested dictionary into wildcard_dict + self.flatten_wildcard_dict(yaml_data, key, wildcard_dict) + elif ext.lower() == '.json': + with open(file_path, 'r', encoding="utf-8") as f: + json_data = json.load(f) + self.flatten_wildcard_dict(json_data, key, wildcard_dict) + else: + print(f"Unsupported file format for wildcards: {file_path}") + except Exception as e: + print(f"Error loading {file_path}: {e}") + + # Read all files in the wildcards directory + for file_name in os.listdir(wildcards_path): + file_path = os.path.join(wildcards_path, file_name) + if os.path.isfile(file_path): + read_wildcard_file(file_path) + + print("Loaded Wildcards:") + for key, values in wildcard_dict.items(): + print(f"{key}: {values}") + return wildcard_dict + + def flatten_wildcard_dict(self, data, parent_key, wildcard_dict): + """Flatten nested dictionaries into wildcard_dict with composite keys and aggregate top-level values.""" + def wildcard_normalize(x): + return x.replace("\\", "/").replace(' ', '-').lower() + + if isinstance(data, dict): + combined_values = [] + for k, v in data.items(): + new_key = f"{parent_key}/{k}" + self.flatten_wildcard_dict(v, new_key, wildcard_dict) + + # Collect all values from subcategories + if isinstance(v, dict) or isinstance(v, list): + sub_values = self.get_all_nested_values({new_key: v}) + combined_values.extend(sub_values) + else: + combined_values.append(v) + + # Move assignment outside the for loop + wildcard_dict[parent_key] = combined_values + elif isinstance(data, list): + wildcard_dict[parent_key] = data + else: + key = wildcard_normalize(parent_key) + wildcard_dict[key] = [data] + + def get_wildcard_values(self, keyword, pattern_modifier, wildcard_dict): + """Retrieve wildcard values based on the pattern modifier.""" + keys_to_search = [keyword] + + if pattern_modifier == '/**': + # Include all nested keys + keys_to_search = [k for k in wildcard_dict.keys() if k.startswith(f"{keyword}/")] + elif pattern_modifier == '/*': + # Include immediate child keys + keys_to_search = [k for k in wildcard_dict.keys() if k.startswith(f"{keyword}/") and '/' not in k[len(keyword)+1:]] + + values = [] + for key in keys_to_search: + vals = wildcard_dict.get(key, []) + if isinstance(vals, list): + values.extend(vals) + else: + values.append(vals) + return values + + def replace_wildcard(self, string, wildcard_dict): + """Replace wildcards in the given string with appropriate values.""" + pattern = r"__(.+?)(/\*{1,2})?__" # {{ edit: Updated regex to capture wildcard and pattern modifiers }} + matches = re.findall(pattern, string) + + replacements_found = False + + for match in matches: + keyword, pattern_modifier = match + pattern_modifier = pattern_modifier or '' + + keyword_normalized = keyword.lower().replace('\\', '/').replace(' ', '-') + + # Handle pattern modifiers + if pattern_modifier == '/**': + values = self.get_wildcard_values(keyword_normalized, '/**', wildcard_dict) + elif pattern_modifier == '/*': + values = self.get_wildcard_values(keyword_normalized, '/*', wildcard_dict) + else: + values = wildcard_dict.get(keyword_normalized, []) + + if not values: + print(f"Error: Wildcard __{keyword}{pattern_modifier}__ not found.") + continue + + replacement = random.choice(values) + string = string.replace(f"__{keyword}{pattern_modifier}__", replacement, 1) + replacements_found = True + + return string, replacements_found + + def process(self, text, dynamic_vars, seed=None): + """Process the text, replacing options and wildcards""" + + if seed is not None: + random.seed(seed) + random_gen = random.Random(seed) + + local_wildcard_dict = self.wildcards.copy() + dynamic_vars_lower = {k.lower(): v for k, v in dynamic_vars.items()} + local_wildcard_dict.update(dynamic_vars_lower) + + def is_numeric_string(input_str): + return re.match(r'^-?\d+(\.\d+)?$', input_str) is not None + + def safe_float(x): + if is_numeric_string(x): + return float(x) + else: + return 1.0 + + def replace_options(string): + replacements_found = False + + def replace_option(match): + nonlocal replacements_found + content = match.group(1) + options = [] + weight_pattern = r'(?:(\d+(?:\.\d+)?)::)?(.*)' + for opt in content.split('|'): + opt = opt.strip() + m = re.match(weight_pattern, opt) + weight = float(m.group(1)) if m.group(1) else 1.0 + value = m.group(2).strip() + options.append((value, weight)) + + # Handle combination syntax + num_select = 1 + select_sep = ' ' + multi_select_pattern = content.split('$$') + if len(multi_select_pattern) > 1: + range_str = multi_select_pattern[0] + options_str = '$$'.join(multi_select_pattern[1:]) + options = [] + for opt in options_str.split('|'): + opt = opt.strip() + m = re.match(weight_pattern, opt) + weight = float(m.group(1)) if m.group(1) else 1.0 + value = m.group(2).strip() + options.append((value, weight)) + + if '-' in range_str: + min_select, max_select = map(int, range_str.split('-')) + num_select = random_gen.randint(min_select, max_select) + else: + num_select = int(range_str) + + total_weight = sum(weight for value, weight in options) + normalized_weights = [weight / total_weight for value, weight in options] + + if num_select > len(options): + selected_items = [value for value, weight in options] + for _ in range(num_select - len(options)): + selected_items.append(random_gen.choice(selected_items)) + else: + selected_items = random_gen.choices( + [value for value, weight in options], + weights=normalized_weights, + k=num_select + ) + + replacement = select_sep.join(selected_items) + replacements_found = True + return replacement + + pattern = r'\{([^{}]*?)\}' + replaced_string = re.sub(pattern, replace_option, string) + return replaced_string, replacements_found + + # Pass 1: replace options + pass1, is_replaced1 = replace_options(text) + + while is_replaced1: + pass1, is_replaced1 = replace_options(pass1) + + # Pass 2: replace wildcards using local_wildcard_dict + text, is_replaced2 = self.replace_wildcard(pass1, local_wildcard_dict) + + stop_unwrap = not is_replaced1 and not is_replaced2 + + return text + + def process_text(self, text, dynamic_vars, max_variants, seed=None): + """Process text replacing wildcards and dynamic variables""" + output_prompts = [] + base_prompts = [p.strip() for p in text.split("\n") if p.strip()] + if not base_prompts: + base_prompts = [""] + + for base_prompt in base_prompts: + try: + for _ in range(max_variants): + processed_prompt = self.process(base_prompt, dynamic_vars, seed) + output_prompts.append(processed_prompt) + except ValueError as e: + print(f"Error: {e}") + continue + + # Ensure unique prompts and respect max_variants + output_prompts = list(dict.fromkeys(output_prompts))[:max_variants] + return output_prompts + + def parse_dynamic_variables(self, text): + """Parse dynamic variables in formats: + prefix={val1|val2}, **prefix**={val1|val2}, __prefix__={val1|val2} + """ + variables = {} + # Match both formats + patterns = [ + r'(\w+)=\{([^}]+)\}', # prefix={val1|val2} + r'\*\*(\w+)\*\*=\{([^}]+)\}', # **prefix**={val1|val2} + r'__(\w+)__=\{([^}]+)\}' # __prefix__={val1|val2} + ] + + for pattern in patterns: + matches = re.finditer(pattern, text) + for match in matches: + category = match.group(1).strip().lower() + values = [v.strip() for v in match.group(2).split('|')] + variables[category] = values + return variables + + def display_text(self, text, select=0, counter=-1, dynamic_prompt="", max_variants=10, wildcard_mode=False): + """Main node processing function""" + try: + # Handle counter + if self._execution_count is None or self._execution_count > counter: + self._execution_count = counter + + if self._execution_count == 0: + return {"ui": {"string": ["Execution blocked: Counter reached 0"]}, + "result": ExecutionBlocker("Counter reached 0")} + + # Parse dynamic variables if provided + dynamic_vars = {} + if dynamic_prompt: + print(f"Processing dynamic prompt: {dynamic_prompt}") + dynamic_vars = self.parse_dynamic_variables(dynamic_prompt) + print(f"Parsed dynamic variables: {dynamic_vars}") + + # Process text + output_prompts = [] + if wildcard_mode: + output_prompts = self.process_text(text, dynamic_vars, max_variants) + else: + output_prompts = [text] + + # Ensure at least one prompt + if not output_prompts: + output_prompts = [text] + + count = len(output_prompts) + selected = output_prompts[select % count] if count > 0 else text + + # Debug output + print("\nIF_AI_tool_output:") + print("==================") + print(f"Mode: {'Wildcard' if wildcard_mode else 'Normal'}") + print(f"Counter: {self._execution_count}") + print(f"Dynamic vars: {dynamic_vars}") + print(f"Variants generated: {count}") + for i, p in enumerate(output_prompts): + print(f"[{i+1}/{count}] {p}") + print("------------------") + print("==================") + + # Update counter if needed + if self._execution_count > 0: + self._execution_count -= 1 + + return { + "ui": {"string": output_prompts}, + "result": (text, output_prompts, count, selected) + } + + except Exception as e: + print(f"Error in display_text: {str(e)}") + traceback.print_exc() + return {"ui": {"string": [f"Error: {str(e)}"]}, + "result": ExecutionBlocker(f"Error: {str(e)}")} + + @classmethod + def IS_CHANGED(cls, text, select, counter, **kwargs): + return counter + + def get_all_nested_values(self, data): + """Recursively get all values from nested structure""" + values = [] + if isinstance(data, dict): + for v in data.values(): + values.extend(self.get_all_nested_values(v)) + elif isinstance(data, list): + for item in data: + if isinstance(item, dict) or isinstance(item, list): + values.extend(self.get_all_nested_values(item)) + else: + values.append(item) + else: + values.append(data) + return values + + def get_root_values(self, data): + """Get only root level values""" + values = [] + if isinstance(data, dict): + for v in data.values(): + if isinstance(v, list): + values.extend(v) + elif isinstance(v, str): + values.append(v) + elif isinstance(data, list): + values.extend(data) + elif isinstance(data, str): + values.append(data) + return values + diff --git a/IFImagePromptNode.py b/IFImagePromptNode.py index 768a7c5..9c38f2a 100644 --- a/IFImagePromptNode.py +++ b/IFImagePromptNode.py @@ -1,388 +1,715 @@ -import json -import requests -import base64 -import textwrap -import io -import os -from io import BytesIO -from PIL import Image -import torch -import tempfile -from torchvision.transforms.functional import to_pil_image -import folder_paths -import sys - -# Add the ComfyUI directory to the Python path -comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) -sys.path.insert(0, comfy_path) - -try: - from server import PromptServer - from aiohttp import web - - @PromptServer.instance.routes.post("/IF_ImagePrompt/get_models") - async def get_models_endpoint(request): - data = await request.json() - engine = data.get("engine") - base_ip = data.get("base_ip") - port = data.get("port") - - node = IFImagePrompt() - models = node.get_models(engine, base_ip, port) - return web.json_response(models) -except AttributeError: - print("PromptServer.instance not available. Skipping route decoration for IF_ImagePrompt.") - async def get_models_endpoint(request): - # Fallback implementation - return web.json_response({"error": "PromptServer.instance not available"}) - -class IFImagePrompt: - - RETURN_TYPES = ("STRING", "STRING", "STRING",) - RETURN_NAMES = ("Question", "Response", "Negative",) - FUNCTION = "describe_picture" - OUTPUT_NODE = True - CATEGORY = "ImpactFrames💥🎞️" - - @classmethod - def INPUT_TYPES(cls): - node = cls() - return { - "required": { - "image": ("IMAGE", ), - "image_prompt": ("STRING", {"multiline": True, "default": ""}), - "base_ip": ("STRING", {"default": node.base_ip}), - "port": ("STRING", {"default": node.port}), - "engine": (["ollama", "openai", "anthropic"], {"default": node.engine}), - #"selected_model": (node.get_models("node.engine", node.base_ip, node.port), {}), - "selected_model": ((), {}), - "profile": ([name for name in node.profiles.keys()], {"default": node.profile}), - "embellish_prompt": ([name for name in node.embellish_prompts.keys()], {}), - "style_prompt": ([name for name in node.style_prompts.keys()], {}), - "neg_prompt": ([name for name in node.neg_prompts.keys()], {}), - "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.1}), - }, - "optional": { - "max_tokens": ("INT", {"default": 160, "min": 1, "max": 8192}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature"}), - "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps_Model", "label_off": "Unloads_Model"}), - }, - "hidden": { - "model": ("STRING", {"default": ""}), - }, - } - - @classmethod - def IS_CHANGED(cls, engine, base_ip, port, keep_alive, profile): - node = cls() - if engine != node.engine or base_ip != node.base_ip or port != node.port or node.selected_model != node.get_models(engine, base_ip, port) or keep_alive != node.keep_alive or profile != profile: - node.engine = engine - node.base_ip = base_ip - node.port = port - node.selected_model = node.get_models(engine, base_ip, port) - node.keep_alive = keep_alive - node.profile = profile - return True - return False - - def __init__(self): - self.base_ip = "localhost" - self.port = "11434" - self.engine = "ollama" - self.selected_model = "" - self.profile = "IF_PromptMKR_IMG" - self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.presets_dir = os.path.join(os.path.dirname(__file__), "presets") - self.profiles_file = os.path.join(self.presets_dir, "profiles.json") - self.profiles = self.load_presets(self.profiles_file) - self.neg_prompts_file = os.path.join(self.presets_dir, "neg_prompts.json") - self.embellish_prompts_file = os.path.join(self.presets_dir, "embellishments.json") - self.style_prompts_file = os.path.join(self.presets_dir, "style_prompts.json") - self.neg_prompts = self.load_presets(self.neg_prompts_file) - self.embellish_prompts = self.load_presets(self.embellish_prompts_file) - self.style_prompts = self.load_presets(self.style_prompts_file) - - - def load_presets(self, file_path): - with open(file_path, 'r') as f: - presets = json.load(f) - return presets - - def get_api_key(self, api_key_name, engine): - if engine != "ollama": - api_key = os.getenv(api_key_name) - if api_key: - return api_key - else: - print(f'you are using ollama as the engine, no api key is required') - - - def get_models(self, engine, base_ip, port): - if engine == "ollama": - api_url = f'http://{base_ip}:{port}/api/tags' - try: - response = requests.get(api_url) - response.raise_for_status() - models = [model['name'] for model in response.json().get('models', [])] - return models - except Exception as e: - print(f"Failed to fetch models from Ollama: {e}") - return [] - elif engine == "anthropic": - return ["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"] - elif engine == "openai": - return ["gpt-4-0125-preview", "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4-1106-vision-preview", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106"] - else: - print(f"Unsupported engine - {engine}") - return [] - - def tensor_to_image(self, tensor): - # Ensure tensor is on CPU - tensor = tensor.cpu() - # Normalize tensor 0-255 and convert to byte - image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() - # Create PIL image - image = Image.fromarray(image_np, mode='RGB') - return image - - - def prepare_messages(self, image_prompt, profile): - profile_selected = self.profiles.get(profile, "") - empty_image_prompt = "Make a visual prompt for the following Image:" - filled_image_prompt = textwrap.dedent("""\ - Act as a visual prompt maker with the following guidelines: - - Describe the image in vivid detail. - - Break keywords by commas. - - Provide high-quality, non-verbose, coherent, concise, and not superfluous descriptions. - - Focus solely on the visual elements of the picture; avoid art commentaries or intentions. - - Construct the prompt by describing framing, subjects, scene elements, background, aesthetics. - - Limit yourself up to 7 keywords per component - - Be varied and creative. - - Always reply on the same line, use around 100 words long. - - Do not enumerate or enunciate components. - - Do not include any additional information in the response. - The following is an illustartive example for you to see how to construct a prompt your prompts should follow this format but always coherent to the subject worldbuilding or setting and consider the elements relationship: - 'Epic, Cover Art, Full body shot, dynamic angle, A Demon Hunter, standing, lone figure, glow eyes, deep purple light, cybernetic exoskeleton, sleek, metallic, glowing blue accents, energy weapons. Fighting Demon, grotesque creature, twisted metal, glowing red eyes, sharp claws, Cyber City, towering structures, shrouded haze, shimmering energy. Ciberpunk, dramatic lighthing, highly detailed. ' - """) - if image_prompt.strip() == "": - system_message = filled_image_prompt - user_message = empty_image_prompt - else: - system_message = profile_selected - user_message = image_prompt - - return system_message, user_message - - - def describe_picture(self, image, engine, selected_model, base_ip, port, image_prompt, embellish_prompt, style_prompt, neg_prompt, temperature, max_tokens, seed, random, keep_alive, profile): - - embellish_content = self.embellish_prompts.get(embellish_prompt, "") - style_content = self.style_prompts.get(style_prompt, "") - neg_content = self.neg_prompts.get(neg_prompt, "") - - # Check the type of the 'image' object - if isinstance(image, torch.Tensor): - # Convert the tensor to a PIL image - pil_image = self.tensor_to_image(image) - elif isinstance(image, Image.Image): - pil_image = image - elif isinstance(image, str) and os.path.isfile(image): - pil_image = Image.open(image) - else: - print(f"Invalid image type: {type(image)}. Expected torch.Tensor, PIL.Image, or file path.") - return "Invalid image type" - - # Convert the PIL image to base64 - buffered = BytesIO() - pil_image.save(buffered, format="PNG") - base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8') - - available_models = self.get_models(engine, base_ip, port) - if available_models is None or selected_model not in available_models: - error_message = f"Invalid model selected: {selected_model} for engine {engine}. Available models: {available_models}" - print(error_message) - raise ValueError(error_message) - - system_message, user_message = self.prepare_messages(image_prompt, profile) - - try: - generated_text = self.send_request(engine, selected_model, base_ip, port, base64_image, system_message, user_message, temperature, max_tokens, seed, random, keep_alive) - description = f"{embellish_content} {generated_text} {style_content}".strip() - return image_prompt, description, neg_content - except Exception as e: - print(f"Exception occurred: {e}") - return "Exception occurred while processing image." - - - def send_request(self, engine, selected_model, base_ip, port, base64_image, system_message, user_message, temperature, max_tokens, seed, random, keep_alive): - if engine == "anthropic": - anthropic_api_key = self.get_api_key("ANTHROPIC_API_KEY", engine) - anthropic_headers = { - "x-api-key": anthropic_api_key, - "anthropic-version": "2023-06-01", - "Content-Type": "application/json" - } - - data = { - "model": selected_model, - "system": system_message, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": base64_image - } - }, - {"type": "text", "text": user_message} - ] - } - ], - "temperature": temperature, - "max_tokens": max_tokens - } - - api_url = 'https://api.anthropic.com/v1/messages' - - response = requests.post(api_url, headers=anthropic_headers, json=data) - if response.status_code == 200: - response_data = response.json() - messages = response_data.get('content', []) - generated_text = ''.join([msg.get('text', '') for msg in messages if msg.get('type') == 'text']) - return generated_text - else: - print(f"Error: Request failed with status code {response.status_code}, Response: {response.text}") - return "Failed to fetch response from Anthropic." - - elif engine == "openai": - openai_api_key = self.get_api_key("OPENAI_API_KEY", engine) - openai_headers = { - "Authorization": f"Bearer {openai_api_key}", - "Content-Type": "application/json" - } - if random == True: - data = { - "model": selected_model, - "messages": [ - { - "role": "system", - "content": system_message - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": user_message - }, - { - "type": "image_url", - "image_url": f"data:image/png;base64,{base64_image}" - } - ] - } - ], - "temperature": temperature, - "seed": seed, - "max_tokens": max_tokens - } - else: - data = { - "model": selected_model, - "messages": [ - { - "role": "system", - "content": system_message - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": user_message - }, - { - "type": "image_url", - "image_url": f"data:image/png;base64,{base64_image}" - } - ] - } - ], - "temperature": temperature, - "max_tokens": max_tokens - } - - api_url = 'https://api.openai.com/v1/chat/completions' - response = requests.post(api_url, headers=openai_headers, json=data) - if response.status_code == 200: - response_data = response.json() - print("Debug Response:", response_data) - choices = response_data.get('choices', []) - if choices: - choice = choices[0] - message = choice.get('message', {}) - generated_text = message.get('content', '') - return generated_text - else: - print("No valid choices in the response.") - print("Full response:", response.text) - return "No valid response generated for the image." - else: - print(f"Failed to fetch response, status code: {response.status_code}") - print("Full response:", response.text) - return "Failed to fetch response from OpenAI." - - else: - api_url = f'http://{base_ip}:{port}/api/generate' - if random == True: - data = { - "model": selected_model, - "system": system_message, - "prompt": user_message, - "stream": False, - "images": [base64_image], - "options": { - "temperature": temperature, - "seed": seed, - "num_ctx": max_tokens - }, - "keep_alive": -1 if keep_alive else 0, - } - else: - data = { - "model": selected_model, - "system": system_message, - "prompt": user_message, - "stream": False, - "images": [base64_image], - "options": { - "temperature": temperature, - "num_ctx": max_tokens, - }, - "keep_alive": -1 if keep_alive else 0, - } - - ollama_headers = {"Content-Type": "application/json"} - response = requests.post(api_url, headers=ollama_headers, json=data) - if response.status_code == 200: - response_data = response.json() - prompt_response = response_data.get('response', 'No response text found') - - # Ensure there is a response to construct the full description - if prompt_response != 'No response text found': - return prompt_response - else: - return "No valid response generated for the image." - else: - print(f"Failed to fetch response, status code: {response.status_code}") - return "Failed to fetch response from Ollama." - - - -NODE_CLASS_MAPPINGS = {"IF_ImagePrompt": IFImagePrompt} -NODE_DISPLAY_NAME_MAPPINGS = {"IF_ImagePrompt": "IF Image to Prompt🖼️"} +# IFImagePromptNode.py +import os +import sys +import json +import torch +import asyncio +import requests +from PIL import Image +from io import BytesIO +from typing import List, Dict, Any, Optional, Union, Tuple +import folder_paths +from .omost import omost_function +from .send_request import send_request +from .utils import ( + get_api_key, + get_models, + process_images_for_comfy, + process_mask, + clean_text, + load_placeholder_image, + validate_models, +) + +# Add ComfyUI directory to path +comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.insert(0, comfy_path) + +# Set up logging +import logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +try: + from server import PromptServer + from aiohttp import web + + @PromptServer.instance.routes.post("/IF_ImagePrompt/get_llm_models") + async def get_llm_models_endpoint(request): + try: + data = await request.json() + llm_provider = data.get("llm_provider") + engine = llm_provider + base_ip = data.get("base_ip") + port = data.get("port") + external_api_key = data.get("external_api_key") + + if external_api_key: + api_key = external_api_key + else: + api_key_name = f"{llm_provider.upper()}_API_KEY" + try: + api_key = get_api_key(api_key_name, engine) + except ValueError: + api_key = None + + node = IFImagePrompt() + models = node.get_models(engine, base_ip, port, api_key) + return web.json_response(models) + + except Exception as e: + print(f"Error in get_llm_models_endpoint: {str(e)}") + return web.json_response([], status=500) + + @PromptServer.instance.routes.post("/IF_ImagePrompt/add_routes") + async def add_routes_endpoint(request): + return web.json_response({"status": "success"}) + +except AttributeError: + print("PromptServer.instance not available. Skipping route decoration for IF_ImagePrompt.") + +class IFImagePrompt: + def __init__(self): + self.strategies = "normal" + # Initialize paths and load presets + self.base_path = folder_paths.base_path + self.presets_dir = os.path.join(self.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "presets") + + # Load preset configurations + self.profiles = self.load_presets(os.path.join(self.presets_dir, "profiles.json")) + self.neg_prompts = self.load_presets(os.path.join(self.presets_dir, "neg_prompts.json")) + self.embellish_prompts = self.load_presets(os.path.join(self.presets_dir, "embellishments.json")) + self.style_prompts = self.load_presets(os.path.join(self.presets_dir, "style_prompts.json")) + self.stop_strings = self.load_presets(os.path.join(self.presets_dir, "stop_strings.json")) + + # Initialize placeholder image path + self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "placeholder.png") + + # Default values + + self.base_ip = "localhost" + self.port = "11434" + self.engine = "xai" + self.selected_model = "" + self.profile = "IF_PromptMKR_IMG" + self.messages = [] + self.keep_alive = False + self.seed = 94687328150 + self.history_steps = 10 + self.external_api_key = "" + self.preset = "Default" + self.precision = "fp16" + self.attention = "sdpa" + self.Omni = None + self.mask = None + self.aspect_ratio = "1:1" + self.keep_alive = False + self.clear_history = False + self.random = False + self.max_tokens = 2048 + self.temperature = 0.7 + self.top_k = 40 + self.top_p = 0.9 + self.repeat_penalty = 1.1 + self.stop = None + self.batch_count = 4 + + @classmethod + def INPUT_TYPES(cls): + node = cls() + return { + "required": { + "images": ("IMAGE", {"list": True}), # Primary image input + "llm_provider": (["xai","llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "transformers"], {}), + "llm_model": ((), {}), + "base_ip": ("STRING", {"default": "localhost"}), + "port": ("STRING", {"default": "11434"}), + "user_prompt": ("STRING", {"multiline": True}), + }, + "optional": { + "strategy": (["normal", "omost", "create", "edit", "variations"], {"default": "normal"}), + "mask": ("MASK", {}), + "prime_directives": ("STRING", {"forceInput": True, "tooltip": "The system prompt for the LLM."}), + "profiles": (["None"] + list(cls().profiles.keys()), {"default": "None", "tooltip": "The pre-defined system_prompt from the json profile file on the presets folder you can edit or make your own will be listed here."}), + "embellish_prompt": (list(cls().embellish_prompts.keys()), {"tooltip": "The pre-defined embellishment from the json embellishments file on the presets folder you can edit or make your own will be listed here."}), + "style_prompt": (list(cls().style_prompts.keys()), {"tooltip": "The pre-defined style from the json style_prompts file on the presets folder you can edit or make your own will be listed here."}), + "neg_prompt": (list(cls().neg_prompts.keys()), {"tooltip": "The pre-defined negative prompt from the json neg_prompts file on the presets folder you can edit or make your own will be listed here."}), + "stop_string": (list(cls().stop_strings.keys()), {"tooltip": "Specifies a string at which text generation should stop."}), + "max_tokens": ("INT", {"default": 2048, "min": 1, "max": 8192, "tooltip": "Maximum number of tokens to generate in the response."}), + "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature", "tooltip": "Toggles between using a fixed seed or temperature-based randomness."}), + "seed": ("INT", {"default": 0, "tooltip": "Random seed for reproducible outputs."}), + "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "tooltip": "Controls randomness in output generation. Higher values increase creativity but may reduce coherence."}), + "top_k": ("INT", {"default": 40, "tooltip": "Limits the next token selection to the K most likely tokens."}), + "top_p": ("FLOAT", {"default": 0.9, "tooltip": "Cumulative probability cutoff for token selection."}), + "repeat_penalty": ("FLOAT", {"default": 1.1, "tooltip": "Penalizes repetition in generated text."}), + "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps Model on Memory", "label_off": "Unloads Model from Memory", "tooltip": "Determines whether to keep the model loaded in memory between calls."}), + "clear_history": ("BOOLEAN", {"default": False, "label_on": "Clear History", "label_off": "Keep History", "tooltip": "Determines whether to clear the history between calls."}), + "history_steps": ("INT", {"default": 10, "tooltip": "Number of steps to keep in history."}), + "aspect_ratio": (["1:1", "16:9", "4:5", "3:4", "5:4", "9:16"], {"default": "1:1", "tooltip": "Aspect ratio for the generated images."}), + "batch_count": ("INT", {"default": 4, "tooltip": "Number of images to generate. only for create, edit and variations strategies."}), + "external_api_key": ("STRING", {"default": "", "tooltip": "If this is not empty, it will be used instead of the API key from the .env file. Make sure it is empty to use the .env file."}), + "precision": (["fp16", "fp32", "bf16"], {"tooltip": "Select preccision on Transformer models."}), + "attention": (["sdpa", "flash_attention_2", "xformers"], {"tooltip": "Select attention mechanism on Transformer models."}), + "Omni": ("OMNI", {"default": None, "tooltip": "Additional input for the selected tool."}), + } + } + + RETURN_TYPES = ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK") + RETURN_NAMES = ("question", "response", "negative", "omni", "generated_images", "mask") + + FUNCTION = "process_image_wrapper" + OUTPUT_NODE = True + CATEGORY = "ImpactFrames💥🎞️" + + def get_models(self, engine, base_ip, port, api_key=None): + return get_models(engine, base_ip, port, api_key) + + def load_presets(self, file_path: str) -> Dict[str, Any]: + try: + with open(file_path, 'r') as f: + return json.load(f) + except Exception as e: + print(f"Error loading presets from {file_path}: {e}") + return {} + + def validate_outputs(self, outputs): + """Helper to validate output types match expectations""" + if len(outputs) != len(self.RETURN_TYPES): + raise ValueError( + f"Expected {len(self.RETURN_TYPES)} outputs, got {len(outputs)}" + ) + + for i, (output, expected_type) in enumerate(zip(outputs, self.RETURN_TYPES)): + if output is None and expected_type in ["IMAGE", "MASK"]: + raise ValueError( + f"Output {i} ({self.RETURN_NAMES[i]}) cannot be None for type {expected_type}" + ) + + async def generate_negative_prompts( + self, + prompt: str, + llm_provider: str, + llm_model: str, + base_ip: str, + port: str, + config: dict, + messages: list = None + ) -> List[str]: + """ + Generate negative prompts for the given input prompt. + + Args: + prompt: Input prompt text + llm_provider: LLM provider name + llm_model: Model name + base_ip: API base IP + port: API port + config: Dict containing generation parameters like seed, temperature etc + messages: Optional message history + + Returns: + List of generated negative prompts + """ + try: + if not prompt: + return [] + + # Get system message for negative prompts + neg_system_message = self.profiles.get("IF_NegativePromptEngineer", "") + + # Generate negative prompts + neg_response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=None, + llm_model=llm_model, + system_message=neg_system_message, + user_message=f"Generate negative prompts for:\n{prompt}", + messages=messages or [], + **config + ) + + if not neg_response: + return [] + + # Split into lines and clean up + neg_lines = [line.strip() for line in neg_response.split('\n') if line.strip()] + + # Match number of prompts + num_prompts = len(prompt.split('\n')) + if len(neg_lines) < num_prompts: + neg_lines.extend([neg_lines[-1] if neg_lines else ""] * (num_prompts - len(neg_lines))) + + return neg_lines[:num_prompts] + + except Exception as e: + logger.error(f"Error generating negative prompts: {str(e)}") + return ["Error generating negative prompt"] * num_prompts + + @classmethod + def IS_CHANGED(cls, **kwargs): + return float("NaN") + + async def process_image( + self, + llm_provider: str, + llm_model: str, + base_ip: str, + port: str, + user_prompt: str, + strategy: str = "normal", + images=None, + prime_directives: Optional[str] = None, + profiles: Optional[str] = None, + embellish_prompt: Optional[str] = None, + style_prompt: Optional[str] = None, + neg_prompt: Optional[str] = None, + stop_string: Optional[str] = None, + max_tokens: int = 2048, + seed: int = 0, + random: bool = False, + temperature: float = 0.8, + top_k: int = 40, + top_p: float = 0.9, + repeat_penalty: float = 1.1, + keep_alive: bool = False, + clear_history: bool = False, + history_steps: int = 10, + external_api_key: str = "", + precision: str = "fp16", + attention: str = "sdpa", + Omni: Optional[str] = None, + aspect_ratio: str = "1:1", + mask: Optional[torch.Tensor] = None, + batch_count: int = 4, + **kwargs + ) -> Union[str, Dict[str, Any]]: + try: + # Initialize variables at the start + formatted_response = None + generated_images = None + generated_masks = None + tool_output = None + + if external_api_key != "": + llm_api_key = external_api_key + else: + llm_api_key = get_api_key(f"{llm_provider.upper()}_API_KEY", llm_provider) + print(f"LLM API key: {llm_api_key[:5]}...") + + # Validate LLM model + validate_models(llm_model, llm_provider, "LLM", base_ip, port, llm_api_key) + + # Handle history + if clear_history: + self.messages = [] + elif history_steps > 0: + self.messages = self.messages[-history_steps:] + + messages = self.messages + + # Handle stop + if stop_string is None or stop_string == "None": + stop_content = None + else: + stop_content = self.stop_strings.get(stop_string, None) + stop = stop_content + + if llm_provider not in ["ollama", "llamacpp", "vllm", "lmstudio", "gemeni"]: + if llm_provider == "kobold": + stop = stop_content + \ + ["\n\n\n\n\n"] if stop_content else ["\n\n\n\n\n"] + elif llm_provider == "mistral": + stop = stop_content + \ + ["\n\n"] if stop_content else ["\n\n"] + else: + stop = stop_content if stop_content else None + + # Prepare embellishments and styles + embellish_content = self.embellish_prompts.get(embellish_prompt, "").strip() if embellish_prompt else "" + style_content = self.style_prompts.get(style_prompt, "").strip() if style_prompt else "" + neg_content = self.neg_prompts.get(neg_prompt, "").strip() if neg_prompt else "" + profile_content = self.profiles.get(profiles, "") + + # Prepare system prompt + if prime_directives is not None: + system_message_str = prime_directives + else: + system_message_str= json.dumps(profile_content) + + if strategy == "omost": + system_prompt = self.profiles.get("IF_Omost") + messages = [] + # Generate the text using LLM + llm_response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images, + llm_model=llm_model, + system_message=system_prompt, + user_message=user_prompt, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + tools=None, + tool_choice=None, + precision=precision, + attention=attention, + aspect_ratio=aspect_ratio, + strategy="omost", + batch_count=batch_count, + mask=mask, + ) + + # Pass the generated_text to omost_function + tool_args = { + "name": "omost_tool", + "description": "Analyzes images composition and generates a Canvas representation.", + "system_prompt": system_prompt, + "input": user_prompt, + "llm_response": llm_response, + "function_call": None, + "omni_input": Omni + } + + tool_result = await omost_function(tool_args) + + # Process the tool output + if "error" in tool_result: + llm_response = f"Error: {tool_result['error']}" + tool_output = None + else: + tool_output = tool_result.get("canvas_conditioning", "") + llm_response = f"{tool_output}" + cleaned_response = clean_text(llm_response) + + neg_content = self.neg_prompts.get(neg_prompt, "").strip() if neg_prompt else "" + + # Update message history if keeping alive + if keep_alive and cleaned_response: + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": cleaned_response}) + + return { + "Question": user_prompt, + "Response": cleaned_response, + "Negative": neg_content, + "Tool_Output": tool_output, + "Retrieved_Image": None, + "Mask": None + } + elif strategy in ["create", "edit", "variations"]: + resulting_images = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images, + llm_model=llm_model, + system_message=system_prompt, + user_message=user_prompt, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + tools=None, + tool_choice=None, + precision=precision, + attention=attention, + aspect_ratio=aspect_ratio, + strategy=strategy, + batch_count=batch_count, + mask=mask, + ) + if isinstance(resulting_images, dict) and "images" in resulting_images: + generated_images = resulting_images["images"] + generated_masks = None + else: + generated_images = None + generated_masks = None + + try: + if generated_images is not None: + if isinstance(generated_images, torch.Tensor): + # Ensure correct format (B, C, H, W) + image_tensor = generated_images.unsqueeze(0) if generated_images.dim() == 3 else generated_images + + # Create matching batch masks + batch_size = image_tensor.shape[0] + height = image_tensor.shape[2] + width = image_tensor.shape[3] + + # Create default masks + mask_tensor = torch.ones((batch_size, 1, height, width), + dtype=torch.float32, + device=image_tensor.device) + + if generated_masks is not None: + mask_tensor = process_mask(generated_masks, image_tensor) + else: + image_tensor, mask_tensor = process_images_for_comfy(generated_images, self.placeholder_image_path) + mask_tensor = process_mask(generated_masks, image_tensor) if generated_masks is not None else mask_tensor + else: + # No retrieved image - use original or placeholder + if images is not None and len(images) > 0: + image_tensor = images[0] if isinstance(images[0], torch.Tensor) else process_images_for_comfy(images, self.placeholder_image_path)[0] + mask_tensor = torch.ones_like(image_tensor[:1]) # Create mask with same spatial dimensions + else: + image_tensor, mask_tensor = load_placeholder_image(self.placeholder_image_path) + + return { + "Question": user_prompt, + "Response": f"{strategy} image has been successfully generated.", + "Negative": neg_content, + "Tool_Output": None, + "Retrieved_Image": image_tensor, + "Mask": mask_tensor + } + + except Exception as e: + print(f"Error in process_image: {str(e)}") + image_tensor, mask_tensor = load_placeholder_image(self.placeholder_image_path) + return { + "Question": user_prompt, + "Response": f"Error: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": image_tensor, + "Mask": mask_tensor + } + elif strategy == "normal": + try: + formatted_responses = [] + final_prompts = [] + final_negative_prompts = [] + + # Handle images as they come from ComfyUI - no extra processing needed + current_images = images if images is not None else None + + # If mask provided, ensure it matches image dimensions + if mask is not None: + mask_tensor = process_mask(mask, current_images) + else: + # Create default mask if needed + if current_images is not None: + mask_tensor = torch.ones((current_images.shape[0], 1, current_images.shape[2], current_images.shape[3]), + dtype=torch.float32, + device=current_images.device) + else: + _, mask_tensor = load_placeholder_image(self.placeholder_image_path) + + # Iterate over batches + for batch_idx in range(batch_count): + try: + response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=current_images, # Pass images directly + llm_model=llm_model, + system_message=system_message_str, + user_message=user_prompt, + messages=messages, + seed=seed + batch_idx if seed != 0 else seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + precision=precision, + attention=attention, + aspect_ratio=aspect_ratio, + strategy="normal", + batch_count=1, + mask=mask_tensor, + ) + + if not response: + raise ValueError("No response received from LLM API") + + # Clean and process response + cleaned_response = clean_text(response) + final_prompts.append(cleaned_response) + + # Handle negative prompts + if neg_prompt == "AI_Fill": + negative_prompt = await self.generate_negative_prompts( + prompt=cleaned_response, + llm_provider=llm_provider, + llm_model=llm_model, + base_ip=base_ip, + port=port, + config={ + "seed": seed + batch_idx if seed != 0 else seed, + "temperature": temperature, + "max_tokens": max_tokens, + "random": random, + "top_k": top_k, + "top_p": top_p, + "repeat_penalty": repeat_penalty + }, + messages=messages + ) + final_negative_prompts.append(negative_prompt[0] if negative_prompt else neg_content) + else: + final_negative_prompts.append(neg_content) + + formatted_responses.append(cleaned_response) + + except Exception as e: + logger.error(f"Error in batch {batch_idx}: {str(e)}") + formatted_responses.append(f"Error in batch {batch_idx}: {str(e)}") + final_negative_prompts.append(f"Error generating negative prompt for batch {batch_idx}") + + # Combine all responses + formatted_response = "\n".join(final_prompts) + neg_content = "\n".join(final_negative_prompts) + + # Update message history if needed + if keep_alive and formatted_response: + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": formatted_response}) + + return { + "Question": user_prompt, + "Response": formatted_response, + "Negative": neg_content, + "Tool_Output": None, + "Retrieved_Image": current_images, # Return original images + "Mask": mask_tensor + } + + except Exception as e: + logger.error(f"Error in normal strategy: {str(e)}") + # Return original images or placeholder on error + if images is not None: + current_images = images # Use original images + if mask is not None: + current_mask = mask + else: + # Create default mask matching image dimensions + current_mask = torch.ones((current_images.shape[0], 1, current_images.shape[2], current_images.shape[3]), + dtype=torch.float32, + device=current_images.device) + else: + current_images, current_mask = load_placeholder_image(self.placeholder_image_path) + + return { + "Question": user_prompt, + "Response": f"Error in processing: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": current_images, + "Mask": current_mask + } + + except Exception as e: + logger.error(f"Error in process_image: {str(e)}") + return { + "Question": kwargs.get("user_prompt", ""), + "Response": f"Error: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": ( + images[0] + if images is not None and len(images) > 0 + else load_placeholder_image(self.placeholder_image_path)[0] + ), + "Mask": ( + torch.ones_like(images[0][:1]) + if images is not None and len(images) > 0 + else load_placeholder_image(self.placeholder_image_path)[1] + ), + } + + def process_image_wrapper(self, **kwargs): + """Wrapper to handle async execution of process_image""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Ensure images is present in kwargs + if 'images' not in kwargs: + raise ValueError("Input images are required") + + # Ensure all other required parameters are present + required_params = ['llm_provider', 'llm_model', 'base_ip', 'port', 'user_prompt'] + missing_params = [p for p in required_params if p not in kwargs] + if missing_params: + raise ValueError(f"Missing required parameters: {', '.join(missing_params)}") + + # Get the result from process_image + result = loop.run_until_complete(self.process_image(**kwargs)) + + # Extract values in the correct order matching RETURN_TYPES + prompt = result.get("Response", "") # This is the formatted prompt + response = result.get("Question", "") # Original question/prompt + negative = result.get("Negative", "") + omni = result.get("Tool_Output") + retrieved_image = result.get("Retrieved_Image") + mask = result.get("Mask") + + # Ensure we have valid image and mask tensors + if retrieved_image is None or not isinstance(retrieved_image, torch.Tensor): + retrieved_image, mask = load_placeholder_image(self.placeholder_image_path) + + # Ensure mask has correct format + if mask is None: + mask = torch.ones((retrieved_image.shape[0], 1, retrieved_image.shape[2], retrieved_image.shape[3]), + dtype=torch.float32, + device=retrieved_image.device) + + # Return tuple matching RETURN_TYPES order: ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK") + return ( + response, # First STRING (question/prompt) + prompt, # Second STRING (generated response) + negative, # Third STRING (negative prompt) + omni, # OMNI + retrieved_image, # IMAGE + mask # MASK + ) + + except Exception as e: + logger.error(f"Error in process_image_wrapper: {str(e)}") + # Create fallback values + image_tensor, mask_tensor = load_placeholder_image(self.placeholder_image_path) + return ( + kwargs.get("user_prompt", ""), # Original prompt + f"Error: {str(e)}", # Error message as response + "", # Empty negative prompt + None, # No OMNI data + image_tensor, # Placeholder image + mask_tensor # Default mask + ) + +# Node registration +NODE_CLASS_MAPPINGS = { + "IF_ImagePrompt": IFImagePrompt +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "IF_ImagePrompt": "IF Image to Prompt🖼️" +} diff --git a/IFLoadImagesNodeS.py b/IFLoadImagesNodeS.py new file mode 100644 index 0000000..50adb3e --- /dev/null +++ b/IFLoadImagesNodeS.py @@ -0,0 +1,719 @@ +# IFLoadImagesNode.py +import os +import re +import torch +import glob +import hashlib +import logging +import numpy as np +from PIL import Image, ImageOps, ImageSequence +import folder_paths +import shutil +from typing import Tuple, List, Dict, Optional +from server import PromptServer +from aiohttp import web +import json + +logger = logging.getLogger(__name__) + +def numerical_sort_key(path): + """Sort file paths by numerical order in filenames""" + parts = re.split('([0-9]+)', os.path.basename(path)) + parts[1::2] = map(int, parts[1::2]) # Convert number parts to integers + return parts + +class ImageManager: + THUMBNAIL_PREFIX = "thb_" + PATH_SEPARATOR = "___" + SUBFOLDER_PREFIX = "dir" + LEVEL_SEPARATOR = "--" + THUMBNAIL_PREFIX = "thb_" + THUMBNAIL_SIZE = (300, 300) + + VALID_EXTENSIONS = { + "none": {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"}, + "png": {".png"}, + "jpg": {".jpg", ".jpeg"}, + "webp": {".webp"}, + "gif": {".gif"}, + "bmp": {".bmp"} + } + + @staticmethod + def sanitize_path_component(component: str) -> str: + """Sanitize path components for safe filename use""" + # Replace problematic characters but maintain readability + sanitized = re.sub(r'[\\/:*?"<>|]', '_', component) + return sanitized.strip() + + @staticmethod + def encode_path_to_filename(original_path: str, base_path: str) -> str: + """ + Convert a full file path to an encoded thumbnail filename that preserves hierarchy + Format: thb_dir--level1--level2--level3___filename.ext + """ + try: + # Normalize paths + original_path = os.path.abspath(original_path) + base_path = os.path.abspath(base_path) + + # Get relative path components + rel_path = os.path.relpath(original_path, base_path) + dir_path = os.path.dirname(rel_path) + filename = os.path.basename(rel_path) + + if dir_path and dir_path != '.': + # Split directory path and sanitize each component + dir_parts = [ImageManager.sanitize_path_component(p) + for p in dir_path.split(os.sep)] + # Create encoded directory string + dir_encoded = (f"{ImageManager.SUBFOLDER_PREFIX}" + f"{ImageManager.LEVEL_SEPARATOR}" + f"{ImageManager.LEVEL_SEPARATOR.join(dir_parts)}") + # Combine with filename + return f"{ImageManager.THUMBNAIL_PREFIX}{dir_encoded}{ImageManager.PATH_SEPARATOR}{filename}" + else: + # No subdirectories + return f"{ImageManager.THUMBNAIL_PREFIX}{filename}" + + except Exception as e: + logger.error(f"Error encoding path {original_path}: {e}") + return f"{ImageManager.THUMBNAIL_PREFIX}{os.path.basename(original_path)}" + + @staticmethod + def decode_thumbnail_name(thumbnail_name: str) -> Tuple[List[str], str]: + """ + Decode a thumbnail filename back into path components and original filename + Returns (path_components, filename) + """ + if not thumbnail_name.startswith(ImageManager.THUMBNAIL_PREFIX): + return [], thumbnail_name + + # Remove prefix + name_without_prefix = thumbnail_name[len(ImageManager.THUMBNAIL_PREFIX):] + + # Split into directory part and filename + parts = name_without_prefix.split(ImageManager.PATH_SEPARATOR) + + if len(parts) == 2: + # Has directory information + dir_part, filename = parts + if dir_part.startswith(ImageManager.SUBFOLDER_PREFIX): + # Extract directory levels + dir_levels = dir_part[len(ImageManager.SUBFOLDER_PREFIX):].split(ImageManager.LEVEL_SEPARATOR) + # Remove empty strings + dir_levels = [level for level in dir_levels if level] + return dir_levels, filename + + # No directory information + return [], parts[-1] + + @staticmethod + def get_original_path(thumbnail_name: str, base_path: str) -> str: + """Convert a thumbnail name back to its original file path""" + try: + dir_levels, filename = ImageManager.decode_thumbnail_name(thumbnail_name) + + if dir_levels: + # Reconstruct path with proper system separators + subpath = os.path.join(*dir_levels) if dir_levels else "" + return os.path.normpath(os.path.join(base_path, subpath, filename)) + else: + return os.path.normpath(os.path.join(base_path, filename)) + except Exception as e: + logger.error(f"Error decoding thumbnail name {thumbnail_name}: {e}") + return os.path.join(base_path, thumbnail_name.replace(ImageManager.THUMBNAIL_PREFIX, "")) + + @staticmethod + def normalize_path(path: str) -> str: + """Normalize path separators to system format""" + return os.path.normpath(path.replace('\\', os.sep).replace('/', os.sep)) + + @staticmethod + def get_relative_path(file_path: str, base_path: str) -> str: + """Get the relative path preserving all folder levels""" + try: + return os.path.relpath(file_path, base_path) + except ValueError: + # Handle case where paths are on different drives + return file_path + + @staticmethod + def get_image_files(folder_path: str, include_subfolders: bool, filter_type: str) -> List[str]: + """Get list of image files with complete path hierarchy""" + valid_exts = ImageManager.VALID_EXTENSIONS.get(filter_type.lower(), ImageManager.VALID_EXTENSIONS["none"]) + found_files = [] + + # Normalize the base folder path + folder_path = ImageManager.normalize_path(folder_path) + + def is_valid_image(filename: str) -> bool: + return any(filename.lower().endswith(ext) for ext in valid_exts) + + try: + if include_subfolders: + for root, _, filenames in os.walk(folder_path): + for filename in filenames: + if is_valid_image(filename): + full_path = os.path.join(root, filename) + # Store absolute paths for consistent handling + found_files.append(os.path.abspath(full_path)) + else: + with os.scandir(folder_path) as entries: + for entry in entries: + if entry.is_file() and is_valid_image(entry.name): + found_files.append(os.path.abspath(entry.path)) + + return found_files + except Exception as e: + logger.error(f"Error getting image files from {folder_path}: {e}") + return [] + + @staticmethod + def encode_path_to_filename(original_path: str, base_path: str) -> str: + """Convert a full file path to an encoded thumbnail filename preserving complete hierarchy""" + try: + # Normalize both paths + original_path = ImageManager.normalize_path(original_path) + base_path = ImageManager.normalize_path(base_path) + + # Get relative path from base_path + rel_path = ImageManager.get_relative_path(original_path, base_path) + + # Split path into components + path_parts = rel_path.split(os.sep) + + if len(path_parts) > 1: + # Join all directory parts with PATH_SEPARATOR + dirs = ImageManager.PATH_SEPARATOR.join(path_parts[:-1]) + filename = path_parts[-1] + return f"{ImageManager.THUMBNAIL_PREFIX}{dirs}{ImageManager.PATH_SEPARATOR}{filename}" + else: + # No subdirectories + return f"{ImageManager.THUMBNAIL_PREFIX}{rel_path}" + except Exception as e: + logger.error(f"Error encoding path {original_path}: {e}") + return f"{ImageManager.THUMBNAIL_PREFIX}{os.path.basename(original_path)}" + + @staticmethod + def decode_thumbnail_name(thumbnail_name: str) -> Tuple[List[str], str]: + """Decode a thumbnail filename back into path components and original filename""" + if not thumbnail_name.startswith(ImageManager.THUMBNAIL_PREFIX): + return [], thumbnail_name + + # Remove prefix + name_without_prefix = thumbnail_name[len(ImageManager.THUMBNAIL_PREFIX):] + + # Split by path separator and handle escaped separators + parts = name_without_prefix.split(ImageManager.PATH_SEPARATOR) + + # Last part is the filename, everything else is path components + return parts[:-1], parts[-1] + + @staticmethod + def get_original_path(thumbnail_name: str, base_path: str) -> str: + """Convert a thumbnail name back to its original file path""" + try: + path_parts, filename = ImageManager.decode_thumbnail_name(thumbnail_name) + + # Normalize base path + base_path = ImageManager.normalize_path(base_path) + + if path_parts: + # Reconstruct path with proper system separators + subpath = os.path.join(*path_parts) + return os.path.normpath(os.path.join(base_path, subpath, filename)) + else: + return os.path.normpath(os.path.join(base_path, filename)) + except Exception as e: + logger.error(f"Error decoding thumbnail name {thumbnail_name}: {e}") + return os.path.join(base_path, thumbnail_name.replace(ImageManager.THUMBNAIL_PREFIX, "")) + + @staticmethod + def create_thumbnails(folder_path: str, include_subfolders: bool = True, + filter_type: str = "none", sort_method: str = "alphabetical", + start_index: int = 0, max_images: Optional[int] = None) -> Tuple[bool, str, List[str], Dict[str, int]]: + try: + input_dir = folder_paths.get_input_directory() + thumbnail_paths = [] + image_order = {} # Track image order + + # Normalize paths + if not os.path.isabs(folder_path): + folder_path = os.path.abspath(os.path.join(folder_paths.get_input_directory(), folder_path)) + folder_path = ImageManager.normalize_path(folder_path) + + if not os.path.exists(folder_path): + return False, f"Path not found: {folder_path}", [], {} + + # Get and filter files + files = ImageManager.get_image_files(folder_path, include_subfolders, filter_type) + if not files: + return False, "No valid images found in the specified path", [], {} + + # Sort files + files = sorted(files, key=numerical_sort_key if sort_method == "numerical" + else os.path.getctime if sort_method == "date_created" + else os.path.getmtime if sort_method == "date_modified" + else str) + + # Clean up existing thumbnails + for f in os.listdir(input_dir): + if f.startswith(ImageManager.THUMBNAIL_PREFIX): + try: + os.remove(os.path.join(input_dir, f)) + except Exception as e: + logger.warning(f"Error removing old thumbnail {f}: {e}") + + # Apply index and count limits + start_idx = min(max(0, start_index), len(files)) + end_idx = len(files) if max_images is None else min(start_idx + max_images, len(files)) + selected_files = files[start_idx:end_idx] + + # Create thumbnails with encoded paths - only for selected range + for idx, file_path in enumerate(selected_files, start=start_idx): + try: + thumb_name = ImageManager.encode_path_to_filename(file_path, folder_path) + thumb_path = os.path.join(input_dir, thumb_name) + + with Image.open(file_path) as img: + img = ImageOps.exif_transpose(img) + + if img.mode in ('RGBA', 'LA'): + background = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'RGBA': + background.paste(img, mask=img.split()[3]) + else: + background.paste(img, mask=img.split()[1]) + img = background + elif img.mode not in ('RGB', 'L'): + img = img.convert('RGB') + + img.thumbnail(ImageManager.THUMBNAIL_SIZE, Image.Resampling.LANCZOS) + img.save(thumb_path, "JPEG", quality=70, optimize=True) + + thumbnail_paths.append(thumb_name) + image_order[thumb_name] = idx # Store image index + logger.info(f"Created thumbnail: {thumb_name} for {file_path}") + + except Exception as e: + logger.warning(f"Error creating thumbnail for {file_path}: {e}") + continue + + if not thumbnail_paths: + return False, "Failed to create any thumbnails", [], {} + + return True, f"Created {len(thumbnail_paths)} thumbnails", thumbnail_paths, image_order + + except Exception as e: + logger.error(f"Thumbnail creation failed: {str(e)}") + return False, f"Thumbnail creation failed: {str(e)}", [], {} + + @staticmethod + def backup_input_folder() -> Tuple[bool, str]: + try: + input_dir = folder_paths.get_input_directory() + backup_dir = os.path.join(os.path.dirname(input_dir), "input_backup") + + # Create backup directory if it doesn't exist + if not os.path.exists(backup_dir): + os.makedirs(backup_dir) + + # First, remove all thumbnail files + for file in os.listdir(input_dir): + if file.startswith(ImageManager.THUMBNAIL_PREFIX): + os.remove(os.path.join(input_dir, file)) + + # Copy remaining files from input to backup + for file in os.listdir(input_dir): + file_path = os.path.join(input_dir, file) + if os.path.isfile(file_path): + shutil.copy2(file_path, backup_dir) + + # Clear input directory + for file in os.listdir(input_dir): + file_path = os.path.join(input_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + + return True, "Input folder backed up successfully" + except Exception as e: + logger.error(f"Backup failed: {str(e)}") + return False, f"Backup failed: {str(e)}" + + @staticmethod + def restore_input_folder() -> Tuple[bool, str]: + try: + input_dir = folder_paths.get_input_directory() + backup_dir = os.path.join(os.path.dirname(input_dir), "input_backup") + + if not os.path.exists(backup_dir): + return False, "Backup directory not found" + + # Clear thumbnails first + for file in os.listdir(input_dir): + if file.startswith(ImageManager.THUMBNAIL_PREFIX): + os.remove(os.path.join(input_dir, file)) + + # Restore original files + for file in os.listdir(backup_dir): + shutil.copy2(os.path.join(backup_dir, file), input_dir) + + return True, "Input folder restored successfully" + except Exception as e: + logger.error(f"Restore failed: {str(e)}") + return False, f"Restore failed: {str(e)}" + + @staticmethod + def sort_files(files: List[str], sort_method: str) -> List[str]: + """Sort files based on selected method""" + if sort_method == "numerical": + return sorted(files, key=numerical_sort_key) + elif sort_method == "date_created": + return sorted(files, key=os.path.getctime) + elif sort_method == "date_modified": + return sorted(files, key=os.path.getmtime) + else: # alphabetical + return sorted(files) + +class IFLoadImagess: + def __init__(self): + self.path_cache = {} # Cache for path mapping + + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + # Count available thumbnails + available_images = len([f for f in os.listdir(input_dir) + if f.startswith(ImageManager.THUMBNAIL_PREFIX)]) + available_images = max(1, available_images) # Ensure at least 1 + + files = [f for f in os.listdir(input_dir) + if f.startswith(ImageManager.THUMBNAIL_PREFIX)] + + return { + "required": { + "image": (sorted(files), {"image_upload": True}), + "input_path": ("STRING", {"default": ""}), + "start_index": ("INT", {"default": 0, "min": 0, "max": 9999}), + "stop_index": ("INT", {"default": 10, "min": 1, "max": 9999}), # Changed to stop_index + "load_limit": (["10", "100", "1000", "10000", "100000"], {"default": "1000"}), + "image_selected": ("BOOLEAN", {"default": False}), + "available_image_count": ("INT", { + "default": available_images, + "min": 0, + "max": 99999, + "readonly": True + }), + "include_subfolders": ("BOOLEAN", {"default": True}), + "sort_method": (["alphabetical", "numerical", "date_created", "date_modified"],), + "filter_type": (["none", "png", "jpg", "jpeg", "webp", "gif", "bmp"],), + } + } + + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "STRING", "STRING", "INT") + RETURN_NAMES = ("images", "masks", "image_paths", "filenames", "count_str", "count_int") + OUTPUT_IS_LIST = (True, True, True, True, True, True) + FUNCTION = "load_images" + CATEGORY = "ImpactFrames💥🎞️" + + @classmethod + def IS_CHANGED(cls, image, input_path="", start_index=0, stop_index=0, max_images=1, + include_subfolders=True, sort_method="numerical", image_selected=False, + filter_type="none", image_name="", unique_id=None, load_limit="1000", available_image_count=0): + """ + Properly handle all input parameters and return NaN to force updates + This matches the input parameters from INPUT_TYPES + """ + try: + # If we have a specific image selected, use its path + if image and not image.startswith("thb_"): + image_path = folder_paths.get_annotated_filepath(image) + if image_path: + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + # For directory-based loads, return NaN to force updates + return float("NaN") + except Exception as e: + logging.warning(f"Error in IS_CHANGED: {e}") + return float("NaN") + + def load_images(self, image="", input_path="", start_index=0, stop_index=10, + load_limit="1000", image_selected=False, available_image_count=0, + include_subfolders=True, sort_method="numerical", + filter_type="none", image_name="", unique_id=None): + try: + # Process input path + abs_path = os.path.abspath(input_path if os.path.isabs(input_path) + else os.path.join(folder_paths.get_input_directory(), input_path)) + + # Get all valid images first + all_files = ImageManager.get_image_files(abs_path, include_subfolders, filter_type) + if not all_files: + logger.warning(f"No valid images found in {abs_path}") + img_tensor, mask = self.load_placeholder() + return ([img_tensor], [mask], [""], [""], ["0/0"], [0]) + + # Sort files + all_files = ImageManager.sort_files(all_files, sort_method) + total_files = len(all_files) + + # Validate indices + start_index = min(max(0, start_index), total_files) + stop_index = min(max(start_index + 1, stop_index), total_files) + num_images = min(stop_index - start_index, int(load_limit)) + + # Generate thumbnails + success, _, all_thumbnails, image_order = ImageManager.create_thumbnails( + abs_path, include_subfolders, filter_type, sort_method, + start_index=start_index, + max_images=num_images + ) + + # Handle image selection + if image_selected and image in image_order: + start_index = image_order[image] + num_images = 1 + + # Create path mapping + self.path_cache = { + thumb: orig for thumb, orig in zip(all_thumbnails, all_files[start_index:start_index + num_images]) + } + + # Process selected range + selected_files = all_files[start_index:start_index + num_images] + selected_thumbnails = all_thumbnails[:num_images] + + # Process selected files + images = [] + masks = [] + paths = [] + filenames = [] + count_strs = [] + count_ints = [] + + for idx, (file_path, thumb_name) in enumerate(zip(selected_files, selected_thumbnails)): + try: + with Image.open(file_path) as img: + img = ImageOps.exif_transpose(img) + + if img.mode == 'I': + img = img.point(lambda i: i * (1 / 255)) + image = img.convert('RGB') + + image_array = np.array(image).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_array)[None,] + + if 'A' in img.getbands(): + mask = np.array(img.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((image_array.shape[0], image_array.shape[1]), + dtype=torch.float32, device="cpu") + + images.append(image_tensor) + masks.append(mask.unsqueeze(0)) + paths.append(file_path) + filenames.append(os.path.basename(file_path)) + count_str = f"{start_index + idx + 1}/{total_files}" # Update count to show global position + count_strs.append(count_str) + count_ints.append(start_index + idx + 1) + + except Exception as e: + logger.error(f"Error processing image {file_path}: {e}") + continue + + if not images: + img_tensor, mask = self.load_placeholder() + return ([img_tensor], [mask], [""], [""], ["0/0"], [0]) + + ui_data = { + "images": all_thumbnails, + "current_thumbnails": selected_thumbnails, + "total_images": total_files, + "path_mapping": self.path_cache, + "available_image_count": total_files, + "image_order": image_order, + "start_index": start_index, + "stop_index": start_index + num_images + } + + return { + "ui": {"values": ui_data}, + "result": (images, masks, paths, filenames, count_strs, count_ints) + } + + except Exception as e: + logger.error(f"Error in load_images: {e}", exc_info=True) + img_tensor, mask = self.load_placeholder() + return ([img_tensor], [mask], [""], [""], ["error"], [0]) + + def load_placeholder(self): + """Creates and returns a placeholder image tensor and mask""" + img = Image.new('RGB', (512, 512), color=(73, 109, 137)) + image_array = np.array(img).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_array)[None,] + mask = torch.zeros((1, image_array.shape[0], image_array.shape[1]), + dtype=torch.float32, device="cpu") + return image_tensor, mask + + def process_single_image(self, image_path: str): + """Process a single image and return appropriate outputs""" + try: + img = Image.open(image_path) + img = ImageOps.exif_transpose(img) + + if img.mode == 'I': + img = img.point(lambda i: i * (1 / 255)) + + image = img.convert("RGB") + image_array = np.array(image).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_array)[None,] + + if 'A' in img.getbands(): + mask = np.array(img.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((image_array.shape[0], image_array.shape[1]), + dtype=torch.float32, device="cpu") + + filename = os.path.basename(image_path) + return ([image_tensor], [mask.unsqueeze(0)], [image_path], [filename], ["1/1"], [1]) + + except Exception as e: + logger.error(f"Error processing single image {image_path}: {e}") + img_tensor, mask = self.load_placeholder() + return ([img_tensor], [mask], [""], [""], ["error"], [0]) + +@PromptServer.instance.routes.post("/ifai/backup_input") +async def backup_input_folder(request): + try: + success, message = ImageManager.backup_input_folder() + return web.json_response({ + "success": success, + "message": message + }) + except Exception as e: + logger.error(f"Error in backup_input_folder route: {str(e)}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + +@PromptServer.instance.routes.post("/ifai/restore_input") +async def restore_input_folder(request): + try: + success, message = ImageManager.restore_input_folder() + return web.json_response({ + "success": success, + "message": message + }) + except Exception as e: + logger.error(f"Error in restore_input_folder route: {str(e)}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + +@PromptServer.instance.routes.post("/ifai/refresh_previews") +async def refresh_previews(request): + try: + data = await request.json() + if not data.get("input_path"): + raise ValueError("No input path provided") + + # Extract parameters + load_limit = int(data.get("load_limit", "1000")) + start_index = int(data.get("start_index", 0)) + stop_index = int(data.get("stop_index", 10)) + include_subfolders = data.get("include_subfolders", True) + filter_type = data.get("filter_type", "none") + sort_method = data.get("sort_method", "alphabetical") + + # Get files + if not os.path.isabs(data["input_path"]): + abs_path = os.path.abspath(os.path.join(folder_paths.get_input_directory(), data["input_path"])) + else: + abs_path = data["input_path"] + + # Get all files and sort them + all_files = ImageManager.get_image_files(abs_path, include_subfolders, filter_type) + if not all_files: + return web.json_response({ + "success": False, + "error": "No valid images found" + }) + + all_files = sorted(all_files, + key=numerical_sort_key if sort_method == "numerical" + else os.path.getctime if sort_method == "date_created" + else os.path.getmtime if sort_method == "date_modified" + else str) + + total_available = len(all_files) + + # Validate and adjust indices + start_index = min(max(0, start_index), total_available) + stop_index = min(max(start_index + 1, stop_index), total_available) + + # Calculate how many images to actually load + num_images = min(stop_index - start_index, load_limit) + + # Create thumbnails only for the selected range + success, message, thumbnails, image_order = ImageManager.create_thumbnails( + data["input_path"], + include_subfolders=include_subfolders, + filter_type=filter_type, + sort_method=sort_method, + start_index=start_index, + max_images=num_images # Only create thumbnails for the range we want + ) + + return web.json_response({ + "success": success, + "message": message, + "thumbnails": thumbnails, + "total_images": total_available, + "visible_images": len(thumbnails), + "start_index": start_index, + "stop_index": stop_index, + "image_order": image_order + }) + + except Exception as e: + logger.error(f"Error in refresh_previews route: {str(e)}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + +# Add route for widget refresh +@PromptServer.instance.routes.post("/ifai/refresh_widgets") +async def refresh_widgets(request): + try: + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = sorted(files) + + return web.json_response({ + "success": True, + "files": files + }) + except Exception as e: + logger.error(f"Error in refresh_widgets route: {str(e)}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + +# Register node class +NODE_CLASS_MAPPINGS = { + "IF_LoadImagesS": IFLoadImagess +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "IF_LoadImagesS": "IF Load Images S 🖼️" +} \ No newline at end of file diff --git a/IFPromptMkrNode.py b/IFPromptMkrNode.py index 2b1628f..6977a99 100644 --- a/IFPromptMkrNode.py +++ b/IFPromptMkrNode.py @@ -1,269 +1,711 @@ -import json -import requests -import os -import sys -import textwrap - -# Add the ComfyUI directory to the Python path -comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) -sys.path.insert(0, comfy_path) - -try: - from server import PromptServer - from aiohttp import web - - @PromptServer.instance.routes.post("/IF_PromptMkr/get_models") - async def get_models_endpoint(request): - data = await request.json() - engine = data.get("engine") - base_ip = data.get("base_ip") - port = data.get("port") - - node = IFPrompt2Prompt() - models = node.get_models(engine, base_ip, port) - return web.json_response(models) -except AttributeError: - print("PromptServer.instance not available. Skipping route decoration.") - async def get_models_endpoint(request): - # Fallback implementation - return web.json_response({"error": "PromptServer.instance not available"}) - -import tempfile - -class IFPrompt2Prompt: - RETURN_TYPES = ("STRING", "STRING", "STRING",) - RETURN_NAMES = ("Question", "Response", "Negative",) - FUNCTION = "sample" - OUTPUT_NODE = False - CATEGORY = "ImpactFrames💥🎞️" - - @classmethod - def INPUT_TYPES(cls): - node = cls() - return { - "required": { - "input_prompt": ("STRING", {"multiline": True, "default": "Ancient mega-structure, small lone figure in the foreground"}), - "base_ip": ("STRING", {"default": node.base_ip}), - "port": ("STRING", {"default": node.port}), - "engine": (["ollama", "openai", "anthropic"], {"default": node.engine}), - #"selected_model": (node.get_models("node.engine", node.base_ip, node.port), {}), - "selected_model": ((), {}), - "profile": ([name for name in node.profiles.keys()], {"default": node.profile}), - "embellish_prompt": ([name for name in node.embellish_prompts.keys()], {}), - "style_prompt": ([name for name in node.style_prompts.keys()], {}), - "neg_prompt": ([name for name in node.neg_prompts.keys()], {}), - "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.1}), - }, - "optional": { - "max_tokens": ("INT", {"default": 256, "min": 1, "max": 8192}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature"}), - "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps_Model", "label_off": "Unloads_Model"}), - }, - "hidden": { - "model": ("STRING", {"default": ""}), - }, - } - @classmethod - def IS_CHANGED(cls, engine, base_ip, port, profile, keep_alive): - node = cls() - if engine != node.engine or base_ip != node.base_ip or port != node.port or node.selected_model != node.get_models(engine, base_ip, port) or keep_alive != node.keep_alive: - node.engine = engine - node.base_ip = base_ip - node.port = port - node.selected_model = node.get_models(engine, base_ip, port) - node.profile = profile - node.keep_alive = keep_alive - return True - return False - - def __init__(self): - self.base_ip = "localhost" - self.port = "11434" - self.engine = "ollama" - self.selected_model = "" - self.profile = "IF_PromptMKR" - self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.presets_dir = os.path.join(os.path.dirname(__file__), "presets") - self.profiles_file = os.path.join(self.presets_dir, "profiles.json") - self.profiles = self.load_presets(self.profiles_file) - self.neg_prompts_file = os.path.join(self.presets_dir, "neg_prompts.json") - self.embellish_prompts_file = os.path.join(self.presets_dir, "embellishments.json") - self.style_prompts_file = os.path.join(self.presets_dir, "style_prompts.json") - self.neg_prompts = self.load_presets(self.neg_prompts_file) - self.embellish_prompts = self.load_presets(self.embellish_prompts_file) - self.style_prompts = self.load_presets(self.style_prompts_file) - - def load_presets(self, file_path): - with open(file_path, 'r') as f: - presets = json.load(f) - return presets - - def get_api_key(self, api_key_name, engine): - if engine != "ollama": - api_key = os.getenv(api_key_name) - if api_key: - return api_key - else: - print(f'you are using ollama as the engine, no api key is required') - - def get_models(self, engine, base_ip, port): - if engine == "ollama": - api_url = f'http://{base_ip}:{port}/api/tags' - try: - response = requests.get(api_url) - response.raise_for_status() - models = [model['name'] for model in response.json().get('models', [])] - return models - except Exception as e: - print(f"Failed to fetch models from Ollama: {e}") - return [] - elif engine == "anthropic": - return ["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"] - elif engine == "openai": - return ["gpt-4-0125-preview", "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4-1106-vision-preview", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106"] - else: - print(f"Unsupported engine - {engine}") - return [] - - def sample(self, input_prompt, engine, base_ip, port, selected_model, embellish_prompt, style_prompt, neg_prompt, temperature, max_tokens, seed, random, keep_alive, profile): - embellish_content = self.embellish_prompts.get(embellish_prompt, "") - style_content = self.style_prompts.get(style_prompt, "") - neg_content = self.neg_prompts.get(neg_prompt, "") - profile_selected = self.profiles.get(profile, "") - - if engine == "anthropic": - data = { - 'model': selected_model, - 'system': profile_selected , - 'messages': [ - {"role": "user", "content": input_prompt} - ], - 'temperature': temperature, - 'max_tokens': max_tokens - } - elif engine == "openai": - if random == True: - data = { - 'model': selected_model, - 'messages': [ - {"role": "system", "content": profile_selected }, - {"role": "user", "content": input_prompt} - ], - 'temperature': temperature, - 'seed': seed, - 'max_tokens': max_tokens - } - else: - data = { - 'model': selected_model, - 'messages': [ - {"role": "system", "content": profile_selected }, - {"role": "user", "content": input_prompt} - ], - 'temperature': temperature, - 'max_tokens': max_tokens - } - else: - if random == True: - data = { - "model": selected_model, - "system": profile_selected , - "prompt": input_prompt, - "stream": False, - "options": { - "temperature": temperature, - "seed": seed, - "num_ctx": max_tokens, - }, - "keep_alive": -1 if keep_alive else 0, - } - else: - data = { - "model": selected_model, - "system": profile_selected , - "prompt": input_prompt, - "stream": False, - "options": { - "temperature": temperature, - "seed": seed, - "num_ctx": max_tokens, - }, - "keep_alive": -1 if keep_alive else 0, - } - - generated_text = self.send_request(engine, base_ip, port, data, headers={"Content-Type": "application/json"}) - - if generated_text: - combined_prompt = f"{embellish_content} {generated_text} {style_content}" - return input_prompt, combined_prompt, neg_content - else: - return None, None, None - - def send_request(self, engine, base_ip, port, data, headers): - if engine == "ollama": - api_url = f'http://{base_ip}:{port}/api/generate' - response = requests.post(api_url, headers=headers, json=data) - if response.status_code == 200: - response_data = response.json() - prompt_response = response_data.get('response', 'No response text found') - - # Ensure there is a response to construct the full description - if prompt_response != 'No response text found': - return prompt_response - else: - return "No valid response generated for the image." - else: - print(f"Failed to fetch response, status code: {response.status_code}") - return "Failed to fetch response from Ollama." - elif engine == "anthropic": - anthropic_api_key = self.get_api_key("ANTHROPIC_API_KEY", engine) - try: - base_url = 'https://api.anthropic.com/v1/messages' - anthropic_headers = { - "x-api-key": anthropic_api_key, - "anthropic-version": "2023-06-01", - "Content-Type": "application/json" - } - response = requests.post(base_url, headers=anthropic_headers, json=data) - if response.status_code == 200: - messages = response.json().get('content', []) - generated_text = ''.join([msg.get('text', '') for msg in messages if msg.get('type') == 'text']) - return generated_text - else: - print(f"Error: Request failed with status code {response.status_code}, Response: {response.text}") - return None - except Exception as e: - print(f"Error: Anthropic request failed - {e}") - return None - elif engine == "openai": - openai_api_key = self.get_api_key("OPENAI_API_KEY", engine) - try: - base_url = 'https://api.openai.com/v1/chat/completions' - openai_headers = { - "Authorization": f"Bearer {openai_api_key}", - "Content-Type": "application/json" - } - response = requests.post(base_url, headers=openai_headers, json=data) - if response.status_code == 200: - response_data = response.json() - print("Debug Response:", response_data) - choices = response_data.get('choices', []) - if choices: - choice = choices[0] - messages = choice.get('message', {'content': ''}) - generated_text = messages.get('content', '') - return generated_text - else: - print("No choices found in response") - return None - else: - print(f"Error: Request failed with status code {response.status_code}, Response: {response.text}") - return None - except Exception as e: - print(f"Error: OpenAI request failed - {e}") - return None, None, None - - -NODE_CLASS_MAPPINGS = {"IF_PromptMkr": IFPrompt2Prompt} -NODE_DISPLAY_NAME_MAPPINGS = {"IF_PromptMkr": "IF Prompt to Prompt💬"} +#IFPromptMkrNode.py +import os +import sys +import json +import torch +import asyncio +import requests +from PIL import Image +from io import BytesIO +from typing import List, Dict, Any, Optional, Union, Tuple +import folder_paths +from .omost import omost_function +from .send_request import send_request +from .utils import ( + get_api_key, + get_models, + process_images_for_comfy, + process_mask, + clean_text, + load_placeholder_image, + validate_models, +) + +# Add ComfyUI directory to path +comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.insert(0, comfy_path) + +# Set up logging +import logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +try: + from server import PromptServer + from aiohttp import web + + @PromptServer.instance.routes.post("/IF_PromptMkr/get_llm_models") + async def get_llm_models_endpoint(request): + try: + data = await request.json() + llm_provider = data.get("llm_provider") + engine = llm_provider + base_ip = data.get("base_ip") + port = data.get("port") + external_api_key = data.get("external_api_key") + + if external_api_key: + api_key = external_api_key + else: + api_key_name = f"{llm_provider.upper()}_API_KEY" + try: + api_key = get_api_key(api_key_name, engine) + except ValueError: + api_key = None + + node = IFPrompt2Prompt() + models = node.get_models(engine, base_ip, port, api_key) + return web.json_response(models) + + except Exception as e: + print(f"Error in get_llm_models_endpoint: {str(e)}") + return web.json_response([], status=500) + + @PromptServer.instance.routes.post("/IF_PromptMkr/add_routes") + async def add_routes_endpoint(request): + return web.json_response({"status": "success"}) + +except AttributeError: + print("PromptServer.instance not available. Skipping route decoration for IF_PromptMkr.") + +class IFPrompt2Prompt: + def __init__(self): + self.strategies = "normal" + # Initialize paths and load presets + self.base_path = folder_paths.base_path + self.presets_dir = os.path.join(self.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "presets") + + # Load preset configurations + self.profiles = self.load_presets(os.path.join(self.presets_dir, "profiles.json")) + self.neg_prompts = self.load_presets(os.path.join(self.presets_dir, "neg_prompts.json")) + self.embellish_prompts = self.load_presets(os.path.join(self.presets_dir, "embellishments.json")) + self.style_prompts = self.load_presets(os.path.join(self.presets_dir, "style_prompts.json")) + self.stop_strings = self.load_presets(os.path.join(self.presets_dir, "stop_strings.json")) + + # Initialize placeholder image path + self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "placeholder.png") + + # Default values + + self.base_ip = "localhost" + self.port = "11434" + self.engine = "xai" + self.selected_model = "" + self.profile = "IF_PromptMKR_IMG" + self.messages = [] + self.keep_alive = False + self.seed = 94687328150 + self.history_steps = 10 + self.external_api_key = "" + self.preset = "Default" + self.precision = "fp16" + self.attention = "sdpa" + self.Omni = None + self.mask = None + self.aspect_ratio = "1:1" + self.keep_alive = False + self.clear_history = False + self.random = False + self.max_tokens = 2048 + self.temperature = 0.7 + self.top_k = 40 + self.top_p = 0.9 + self.repeat_penalty = 1.1 + self.stop = None + self.batch_count = 4 + + @classmethod + def INPUT_TYPES(cls): + node = cls() + return { + "required": { + "images": ("IMAGE", {"list": True}), # Primary image input + "llm_provider": (["xai","llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "transformers"], {}), + "llm_model": ((), {}), + "base_ip": ("STRING", {"default": "localhost"}), + "port": ("STRING", {"default": "11434"}), + "user_prompt": ("STRING", {"multiline": True}), + }, + "optional": { + "strategy": (["normal", "omost", "create", "edit", "variations"], {"default": "normal"}), + "mask": ("MASK", {}), + "prime_directives": ("STRING", {"forceInput": True, "tooltip": "The system prompt for the LLM."}), + "profiles": (["None"] + list(cls().profiles.keys()), {"default": "None", "tooltip": "The pre-defined system_prompt from the json profile file on the presets folder you can edit or make your own will be listed here."}), + "embellish_prompt": (list(cls().embellish_prompts.keys()), {"tooltip": "The pre-defined embellishment from the json embellishments file on the presets folder you can edit or make your own will be listed here."}), + "style_prompt": (list(cls().style_prompts.keys()), {"tooltip": "The pre-defined style from the json style_prompts file on the presets folder you can edit or make your own will be listed here."}), + "neg_prompt": (list(cls().neg_prompts.keys()), {"tooltip": "The pre-defined negative prompt from the json neg_prompts file on the presets folder you can edit or make your own will be listed here."}), + "stop_string": (list(cls().stop_strings.keys()), {"tooltip": "Specifies a string at which text generation should stop."}), + "max_tokens": ("INT", {"default": 2048, "min": 1, "max": 8192, "tooltip": "Maximum number of tokens to generate in the response."}), + "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature", "tooltip": "Toggles between using a fixed seed or temperature-based randomness."}), + "seed": ("INT", {"default": 0, "tooltip": "Random seed for reproducible outputs."}), + "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "tooltip": "Controls randomness in output generation. Higher values increase creativity but may reduce coherence."}), + "top_k": ("INT", {"default": 40, "tooltip": "Limits the next token selection to the K most likely tokens."}), + "top_p": ("FLOAT", {"default": 0.9, "tooltip": "Cumulative probability cutoff for token selection."}), + "repeat_penalty": ("FLOAT", {"default": 1.1, "tooltip": "Penalizes repetition in generated text."}), + "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps Model on Memory", "label_off": "Unloads Model from Memory", "tooltip": "Determines whether to keep the model loaded in memory between calls."}), + "clear_history": ("BOOLEAN", {"default": False, "label_on": "Clear History", "label_off": "Keep History", "tooltip": "Determines whether to clear the history between calls."}), + "history_steps": ("INT", {"default": 10, "tooltip": "Number of steps to keep in history."}), + "aspect_ratio": (["1:1", "16:9", "4:5", "3:4", "5:4", "9:16"], {"default": "1:1", "tooltip": "Aspect ratio for the generated images."}), + "batch_count": ("INT", {"default": 4, "tooltip": "Number of images to generate. only for create, edit and variations strategies."}), + "external_api_key": ("STRING", {"default": "", "tooltip": "If this is not empty, it will be used instead of the API key from the .env file. Make sure it is empty to use the .env file."}), + "precision": (["fp16", "fp32", "bf16"], {"tooltip": "Select preccision on Transformer models."}), + "attention": (["sdpa", "flash_attention_2", "xformers"], {"tooltip": "Select attention mechanism on Transformer models."}), + "Omni": ("OMNI", {"default": None, "tooltip": "Additional input for the selected tool."}), + } + } + + RETURN_TYPES = ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK") + RETURN_NAMES = ("question", "response", "negative", "omni", "generated_images", "mask") + + FUNCTION = "process_image_wrapper" + OUTPUT_NODE = True + CATEGORY = "ImpactFrames💥🎞️" + + def get_models(self, engine, base_ip, port, api_key=None): + return get_models(engine, base_ip, port, api_key) + + def load_presets(self, file_path: str) -> Dict[str, Any]: + try: + with open(file_path, 'r') as f: + return json.load(f) + except Exception as e: + print(f"Error loading presets from {file_path}: {e}") + return {} + + def validate_outputs(self, outputs): + """Helper to validate output types match expectations""" + if len(outputs) != len(self.RETURN_TYPES): + raise ValueError( + f"Expected {len(self.RETURN_TYPES)} outputs, got {len(outputs)}" + ) + + for i, (output, expected_type) in enumerate(zip(outputs, self.RETURN_TYPES)): + if output is None and expected_type in ["IMAGE", "MASK"]: + raise ValueError( + f"Output {i} ({self.RETURN_NAMES[i]}) cannot be None for type {expected_type}" + ) + + async def generate_negative_prompts( + self, + prompt: str, + llm_provider: str, + llm_model: str, + base_ip: str, + port: str, + config: dict, + messages: list = None + ) -> List[str]: + """ + Generate negative prompts for the given input prompt. + + Args: + prompt: Input prompt text + llm_provider: LLM provider name + llm_model: Model name + base_ip: API base IP + port: API port + config: Dict containing generation parameters like seed, temperature etc + messages: Optional message history + + Returns: + List of generated negative prompts + """ + try: + if not prompt: + return [] + + # Get system message for negative prompts + neg_system_message = self.profiles.get("IF_NegativePromptEngineer", "") + + # Generate negative prompts + neg_response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=None, + llm_model=llm_model, + system_message=neg_system_message, + user_message=f"Generate negative prompts for:\n{prompt}", + messages=messages or [], + **config + ) + + if not neg_response: + return [] + + # Split into lines and clean up + neg_lines = [line.strip() for line in neg_response.split('\n') if line.strip()] + + # Match number of prompts + num_prompts = len(prompt.split('\n')) + if len(neg_lines) < num_prompts: + neg_lines.extend([neg_lines[-1] if neg_lines else ""] * (num_prompts - len(neg_lines))) + + return neg_lines[:num_prompts] + + except Exception as e: + logger.error(f"Error generating negative prompts: {str(e)}") + return ["Error generating negative prompt"] * num_prompts + + @classmethod + def IS_CHANGED(cls, **kwargs): + return float("NaN") + + async def process_image( + self, + llm_provider: str, + llm_model: str, + base_ip: str, + port: str, + user_prompt: str, + strategy: str = "normal", + images=None, + prime_directives: Optional[str] = None, + profiles: Optional[str] = None, + embellish_prompt: Optional[str] = None, + style_prompt: Optional[str] = None, + neg_prompt: Optional[str] = None, + stop_string: Optional[str] = None, + max_tokens: int = 2048, + seed: int = 0, + random: bool = False, + temperature: float = 0.8, + top_k: int = 40, + top_p: float = 0.9, + repeat_penalty: float = 1.1, + keep_alive: bool = False, + clear_history: bool = False, + history_steps: int = 10, + external_api_key: str = "", + precision: str = "fp16", + attention: str = "sdpa", + Omni: Optional[str] = None, + aspect_ratio: str = "1:1", + mask: Optional[torch.Tensor] = None, + batch_count: int = 4, + **kwargs + ) -> Union[str, Dict[str, Any]]: + try: + # Initialize variables at the start + formatted_response = None + generated_images = None + generated_masks = None + tool_output = None + + if external_api_key != "": + llm_api_key = external_api_key + else: + llm_api_key = get_api_key(f"{llm_provider.upper()}_API_KEY", llm_provider) + print(f"LLM API key: {llm_api_key[:5]}...") + + # Validate LLM model + validate_models(llm_model, llm_provider, "LLM", base_ip, port, llm_api_key) + + # Handle history + if clear_history: + self.messages = [] + elif history_steps > 0: + self.messages = self.messages[-history_steps:] + + messages = self.messages + + # Handle stop + if stop_string is None or stop_string == "None": + stop_content = None + else: + stop_content = self.stop_strings.get(stop_string, None) + stop = stop_content + + if llm_provider not in ["ollama", "llamacpp", "vllm", "lmstudio", "gemeni"]: + if llm_provider == "kobold": + stop = stop_content + \ + ["\n\n\n\n\n"] if stop_content else ["\n\n\n\n\n"] + elif llm_provider == "mistral": + stop = stop_content + \ + ["\n\n"] if stop_content else ["\n\n"] + else: + stop = stop_content if stop_content else None + + # Prepare embellishments and styles + embellish_content = self.embellish_prompts.get(embellish_prompt, "").strip() if embellish_prompt else "" + style_content = self.style_prompts.get(style_prompt, "").strip() if style_prompt else "" + neg_content = self.neg_prompts.get(neg_prompt, "").strip() if neg_prompt else "" + profile_content = self.profiles.get(profiles, "") + + # Prepare system prompt + if prime_directives is not None: + system_message_str = prime_directives + else: + system_message_str= json.dumps(profile_content) + + if strategy == "omost": + system_prompt = self.profiles.get("IF_Omost") + messages = [] + # Generate the text using LLM + llm_response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images, + llm_model=llm_model, + system_message=system_prompt, + user_message=user_prompt, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + tools=None, + tool_choice=None, + precision=precision, + attention=attention, + aspect_ratio=aspect_ratio, + strategy="omost", + batch_count=batch_count, + mask=mask, + ) + + # Pass the generated_text to omost_function + tool_args = { + "name": "omost_tool", + "description": "Analyzes images composition and generates a Canvas representation.", + "system_prompt": system_prompt, + "input": user_prompt, + "llm_response": llm_response, + "function_call": None, + "omni_input": Omni + } + + tool_result = await omost_function(tool_args) + + # Process the tool output + if "error" in tool_result: + llm_response = f"Error: {tool_result['error']}" + tool_output = None + else: + tool_output = tool_result.get("canvas_conditioning", "") + llm_response = f"{tool_output}" + cleaned_response = clean_text(llm_response) + + neg_content = self.neg_prompts.get(neg_prompt, "").strip() if neg_prompt else "" + + # Update message history if keeping alive + if keep_alive and cleaned_response: + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": cleaned_response}) + + return { + "Question": user_prompt, + "Response": cleaned_response, + "Negative": neg_content, + "Tool_Output": tool_output, + "Retrieved_Image": None, + "Mask": None + } + elif strategy in ["create", "edit", "variations"]: + resulting_images = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images, + llm_model=llm_model, + system_message=system_prompt, + user_message=user_prompt, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + tools=None, + tool_choice=None, + precision=precision, + attention=attention, + aspect_ratio=aspect_ratio, + strategy=strategy, + batch_count=batch_count, + mask=mask, + ) + if isinstance(resulting_images, dict) and "images" in resulting_images: + generated_images = resulting_images["images"] + generated_masks = None + else: + generated_images = None + generated_masks = None + + try: + if generated_images is not None: + if isinstance(generated_images, torch.Tensor): + # Ensure correct format (B, C, H, W) + image_tensor = generated_images.unsqueeze(0) if generated_images.dim() == 3 else generated_images + + # Create matching batch masks + batch_size = image_tensor.shape[0] + height = image_tensor.shape[2] + width = image_tensor.shape[3] + + # Create default masks + mask_tensor = torch.ones((batch_size, 1, height, width), + dtype=torch.float32, + device=image_tensor.device) + + if generated_masks is not None: + mask_tensor = process_mask(generated_masks, image_tensor) + else: + image_tensor, mask_tensor = process_images_for_comfy(generated_images, self.placeholder_image_path) + mask_tensor = process_mask(generated_masks, image_tensor) if generated_masks is not None else mask_tensor + else: + # No retrieved image - use original or placeholder + if images is not None and len(images) > 0: + image_tensor = images[0] if isinstance(images[0], torch.Tensor) else process_images_for_comfy(images, self.placeholder_image_path)[0] + mask_tensor = torch.ones_like(image_tensor[:1]) # Create mask with same spatial dimensions + else: + image_tensor, mask_tensor = load_placeholder_image(self.placeholder_image_path) + + return { + "Question": user_prompt, + "Response": f"{strategy} image has been successfully generated.", + "Negative": neg_content, + "Tool_Output": None, + "Retrieved_Image": image_tensor, + "Mask": mask_tensor + } + + except Exception as e: + print(f"Error in process_image: {str(e)}") + image_tensor, mask_tensor = load_placeholder_image(self.placeholder_image_path) + return { + "Question": user_prompt, + "Response": f"Error: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": image_tensor, + "Mask": mask_tensor + } + elif strategy == "normal": + try: + formatted_responses = [] + final_prompts = [] + final_negative_prompts = [] + + # Handle images as they come from ComfyUI - no extra processing needed + current_images = images if images is not None else None + + # If mask provided, ensure it matches image dimensions + if mask is not None: + mask_tensor = process_mask(mask, current_images) + else: + # Create default mask if needed + if current_images is not None: + mask_tensor = torch.ones((current_images.shape[0], 1, current_images.shape[2], current_images.shape[3]), + dtype=torch.float32, + device=current_images.device) + else: + _, mask_tensor = load_placeholder_image(self.placeholder_image_path) + + # Iterate over batches + for batch_idx in range(batch_count): + try: + response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=current_images, # Pass images directly + llm_model=llm_model, + system_message=system_message_str, + user_message=user_prompt, + messages=messages, + seed=seed + batch_idx if seed != 0 else seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key, + precision=precision, + attention=attention, + aspect_ratio=aspect_ratio, + strategy="normal", + batch_count=1, + mask=mask_tensor, + ) + + if not response: + raise ValueError("No response received from LLM API") + + # Clean and process response + cleaned_response = clean_text(response) + final_prompts.append(cleaned_response) + + # Handle negative prompts + if neg_prompt == "AI_Fill": + negative_prompt = await self.generate_negative_prompts( + prompt=cleaned_response, + llm_provider=llm_provider, + llm_model=llm_model, + base_ip=base_ip, + port=port, + config={ + "seed": seed + batch_idx if seed != 0 else seed, + "temperature": temperature, + "max_tokens": max_tokens, + "random": random, + "top_k": top_k, + "top_p": top_p, + "repeat_penalty": repeat_penalty + }, + messages=messages + ) + final_negative_prompts.append(negative_prompt[0] if negative_prompt else neg_content) + else: + final_negative_prompts.append(neg_content) + + formatted_responses.append(cleaned_response) + + except Exception as e: + logger.error(f"Error in batch {batch_idx}: {str(e)}") + formatted_responses.append(f"Error in batch {batch_idx}: {str(e)}") + final_negative_prompts.append(f"Error generating negative prompt for batch {batch_idx}") + + # Combine all responses + formatted_response = "\n".join(final_prompts) + neg_content = "\n".join(final_negative_prompts) + + # Update message history if needed + if keep_alive and formatted_response: + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": formatted_response}) + + return { + "Question": user_prompt, + "Response": formatted_response, + "Negative": neg_content, + "Tool_Output": None, + "Retrieved_Image": current_images, # Return original images + "Mask": mask_tensor + } + + except Exception as e: + logger.error(f"Error in normal strategy: {str(e)}") + # Return original images or placeholder on error + if images is not None: + current_images = images # Use original images + if mask is not None: + current_mask = mask + else: + # Create default mask matching image dimensions + current_mask = torch.ones((current_images.shape[0], 1, current_images.shape[2], current_images.shape[3]), + dtype=torch.float32, + device=current_images.device) + else: + current_images, current_mask = load_placeholder_image(self.placeholder_image_path) + + return { + "Question": user_prompt, + "Response": f"Error in processing: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": current_images, + "Mask": current_mask + } + + except Exception as e: + logger.error(f"Error in process_image: {str(e)}") + return { + "Question": kwargs.get("user_prompt", ""), + "Response": f"Error: {str(e)}", + "Negative": "", + "Tool_Output": None, + "Retrieved_Image": ( + images[0] + if images is not None and len(images) > 0 + else load_placeholder_image(self.placeholder_image_path)[0] + ), + "Mask": ( + torch.ones_like(images[0][:1]) + if images is not None and len(images) > 0 + else load_placeholder_image(self.placeholder_image_path)[1] + ), + } + + + def process_image_wrapper(self, **kwargs): + """Wrapper to handle async execution of process_image""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Ensure images is present in kwargs + if 'images' not in kwargs: + raise ValueError("Input images are required") + + # Ensure all other required parameters are present + required_params = ['llm_provider', 'llm_model', 'base_ip', 'port', 'user_prompt'] + missing_params = [p for p in required_params if p not in kwargs] + if missing_params: + raise ValueError(f"Missing required parameters: {', '.join(missing_params)}") + + # Get the result from process_image + result = loop.run_until_complete(self.process_image(**kwargs)) + + # Extract values in the correct order matching RETURN_TYPES + prompt = result.get("Response", "") # This is the formatted prompt + response = result.get("Question", "") # Original question/prompt + negative = result.get("Negative", "") + omni = result.get("Tool_Output") + retrieved_image = result.get("Retrieved_Image") + mask = result.get("Mask") + + # Ensure we have valid image and mask tensors + if retrieved_image is None or not isinstance(retrieved_image, torch.Tensor): + retrieved_image, mask = load_placeholder_image(self.placeholder_image_path) + + # Ensure mask has correct format + if mask is None: + mask = torch.ones((retrieved_image.shape[0], 1, retrieved_image.shape[2], retrieved_image.shape[3]), + dtype=torch.float32, + device=retrieved_image.device) + + # Return tuple matching RETURN_TYPES order: ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK") + return ( + response, # First STRING (question/prompt) + prompt, # Second STRING (generated response) + negative, # Third STRING (negative prompt) + omni, # OMNI + retrieved_image, # IMAGE + mask # MASK + ) + + except Exception as e: + logger.error(f"Error in process_image_wrapper: {str(e)}") + # Create fallback values + image_tensor, mask_tensor = load_placeholder_image(self.placeholder_image_path) + return ( + kwargs.get("user_prompt", ""), # Original prompt + f"Error: {str(e)}", # Error message as response + "", # Empty negative prompt + None, # No OMNI data + image_tensor, # Placeholder image + mask_tensor # Default mask + ) + + +NODE_CLASS_MAPPINGS = {"IF_PromptMkr": IFPrompt2Prompt} +NODE_DISPLAY_NAME_MAPPINGS = {"IF_PromptMkr": "IF Prompt to Prompt💬"} \ No newline at end of file diff --git a/IFSaveTextNode.py b/IFSaveTextNode.py index c4d1b30..20eaf9f 100644 --- a/IFSaveTextNode.py +++ b/IFSaveTextNode.py @@ -1,82 +1,82 @@ -import os -import csv -import json -import folder_paths -import uuid - -class IFSaveText: - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "question_input": ("STRING", {"forceInput": True}), - "response_input": ("STRING", {"forceInput": True}), - "negative_input": ("STRING", {"forceInput": True}), - #"turn": ("STRING", {"forceInput": True}), - }, - "optional": { - "save_file": ("BOOLEAN", {"default": False, "label_on": "Save Text", "label_off": "Don't Save"}), - "file_format": (["csv", "txt", "json"],), - "save_mode": (["create", "overwrite", "append"],), - }, - #"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING",) - RETURN_NAMES = ("Question", "Response", "Negative", "Turn",) - FUNCTION = "process_text" - OUTPUT_NODE = True - CATEGORY = "ImpactFrames💥🎞️" - - def process_text(self, question_input, negative_input, response_input, save_file=False, file_format="txt", save_mode="create"): - turn_id = str(uuid.uuid4()) - turn_data = {"id": turn_id, "question": question_input, "response": response_input, "negative": negative_input} - if save_file: - self.save_text_to_file(turn_data, file_format, save_mode) - - turn = f"ID: {turn_id}\nQuestion: {question_input}\nResponse: {response_input}\nNegative: {negative_input}" - return (question_input, response_input, negative_input, turn) - - def save_text_to_file(self, turn_data, file_format, save_mode): - save_text_dir = folder_paths.get_output_directory() - os.makedirs(save_text_dir, exist_ok=True) - file_path = os.path.join(save_text_dir, f"output.{file_format}") - - file_mode = "w" if save_mode in ["create", "overwrite"] else "a" - - if file_format == "csv": - with open(file_path, file_mode, newline='') as csvfile: - fieldnames = ['question', 'response'] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - - if save_mode == "create" or save_mode == "overwrite": - writer.writeheader() - writer.writerow(turn_data) - - elif file_format == "txt": - with open(file_path, file_mode) as txtfile: - txtfile.write(f"{turn_data}\n") - - elif file_format == "json": - with open(file_path, file_mode) as jsonfile: - if save_mode == "append": - try: - data = json.load(jsonfile) - except: - data = [] - data.append(turn_data) - jsonfile.seek(0) - else: - data = [turn_data] - json.dump(data, jsonfile, indent=4) - - """@classmethod - def IS_CHANGED(cls, turn_id, question_input, negative_input, response_input, turn, save_file, file_format, save_mode, unique_id=None, prompt=None, extra_pnginfo=None): - turn = f"ID: {turn_id}\nQuestion: {question_input}\nResponse: {response_input}\nNegative: {negative_input}" - return {"ui": {"string": [turn]}, "result": (turn,)}""" - -NODE_CLASS_MAPPINGS = {"IF_saveText": IFSaveText} -NODE_DISPLAY_NAME_MAPPINGS = {"IF_saveText": "IF Save Text📝"} +import os +import csv +import json +import folder_paths +import uuid + +class IFSaveText: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "question_input": ("STRING", {"forceInput": True}), + "response_input": ("STRING", {"forceInput": True}), + "negative_input": ("STRING", {"forceInput": True}), + #"turn": ("STRING", {"forceInput": True}), + }, + "optional": { + "save_file": ("BOOLEAN", {"default": False, "label_on": "Save Text", "label_off": "Don't Save"}), + "file_format": (["csv", "txt", "json"],), + "save_mode": (["create", "overwrite", "append"],), + }, + #"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING",) + RETURN_NAMES = ("Question", "Response", "Negative", "Turn",) + FUNCTION = "process_text" + OUTPUT_NODE = True + CATEGORY = "ImpactFrames💥🎞️" + + def process_text(self, question_input, negative_input, response_input, save_file=False, file_format="txt", save_mode="create"): + turn_id = str(uuid.uuid4()) + turn_data = {"id": turn_id, "question": question_input, "response": response_input, "negative": negative_input} + if save_file: + self.save_text_to_file(turn_data, file_format, save_mode) + + turn = f"ID: {turn_id}\nQuestion: {question_input}\nResponse: {response_input}\nNegative: {negative_input}" + return (question_input, response_input, negative_input, turn) + + def save_text_to_file(self, turn_data, file_format, save_mode): + save_text_dir = folder_paths.get_output_directory() + os.makedirs(save_text_dir, exist_ok=True) + file_path = os.path.join(save_text_dir, f"output.{file_format}") + + file_mode = "w" if save_mode in ["create", "overwrite"] else "a" + + if file_format == "csv": + with open(file_path, file_mode, newline='') as csvfile: + fieldnames = ['question', 'response'] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if save_mode == "create" or save_mode == "overwrite": + writer.writeheader() + writer.writerow(turn_data) + + elif file_format == "txt": + with open(file_path, file_mode) as txtfile: + txtfile.write(f"{turn_data}\n") + + elif file_format == "json": + with open(file_path, file_mode) as jsonfile: + if save_mode == "append": + try: + data = json.load(jsonfile) + except: + data = [] + data.append(turn_data) + jsonfile.seek(0) + else: + data = [turn_data] + json.dump(data, jsonfile, indent=4) + + """@classmethod + def IS_CHANGED(cls, turn_id, question_input, negative_input, response_input, turn, save_file, file_format, save_mode, unique_id=None, prompt=None, extra_pnginfo=None): + turn = f"ID: {turn_id}\nQuestion: {question_input}\nResponse: {response_input}\nNegative: {negative_input}" + return {"ui": {"string": [turn]}, "result": (turn,)}""" + +NODE_CLASS_MAPPINGS = {"IF_saveText": IFSaveText} +NODE_DISPLAY_NAME_MAPPINGS = {"IF_saveText": "IF Save Text📝"} diff --git a/IFTextTyperNode.py b/IFTextTyperNode.py index b089694..b565da9 100644 --- a/IFTextTyperNode.py +++ b/IFTextTyperNode.py @@ -1,24 +1,24 @@ -class IFTextTyper: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "text": ("STRING", {"multiline": True}) - } - } - - RETURN_TYPES = ("STRING",) - FUNCTION = "output_text" - OUTPUT_NODE = True - CATEGORY = "ImpactFrames💥🎞️" - - def output_text(self, text): - return (text,) - -NODE_CLASS_MAPPINGS = { - "IF_TextTyper": IFTextTyper -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "IF_TextTyper": "IF Text Typer✍️" +class IFTextTyper: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "text": ("STRING", {"multiline": True}) + } + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "output_text" + OUTPUT_NODE = True + CATEGORY = "ImpactFrames💥🎞️" + + def output_text(self, text): + return (text,) + +NODE_CLASS_MAPPINGS = { + "IF_TextTyper": IFTextTyper +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "IF_TextTyper": "IF Text Typer✍️" } \ No newline at end of file diff --git a/IFVisualizeGraphNode.py b/IFVisualizeGraphNode.py index e78da4b..74cdbae 100644 --- a/IFVisualizeGraphNode.py +++ b/IFVisualizeGraphNode.py @@ -1,47 +1,47 @@ -import json -import networkx as nx -import os - -class IFVisualizeGraphNode: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "graph_data": ("STRING", {"tooltip": "GraphML file path"}), - }, - "optional": { - "layout": (["spring", "circular", "random", "shell", "spectral"], {"default": "spring"}), - } - } - - RETURN_TYPES = () - FUNCTION = "visualize_graph" - CATEGORY = "ImpactFrames💥🎞️" - OUTPUT_NODE = True - - def visualize_graph(self, graph_data, layout="spring"): - print(f"Visualizing graph: {graph_data}, layout: {layout}") - try: - if not os.path.exists(graph_data): - print(f"GraphML file not found: {graph_data}") - return {}, {"ui": {"error": f"GraphML file not found: {graph_data}"}} - - G = nx.read_graphml(graph_data) - graph_json = json.dumps(nx.node_link_data(G)) - print(f"Graph JSON (first 100 chars): {graph_json[:100]}...") - - return {}, {"ui": {"graph": graph_json, "layout": layout}} - - except Exception as e: - import traceback - error_message = f"Error: {str(e)}\nTraceback:\n{traceback.format_exc()}" - print(error_message) - return {}, {"ui": {"error": error_message}} - -NODE_CLASS_MAPPINGS = { - "IF_VisualizeGraph": IFVisualizeGraphNode -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "IF_VisualizeGraph": "IF Visualize Graph🕸️" +import json +import networkx as nx +import os + +class IFVisualizeGraphNode: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "graph_data": ("STRING", {"tooltip": "GraphML file path"}), + }, + "optional": { + "layout": (["spring", "circular", "random", "shell", "spectral"], {"default": "spring"}), + } + } + + RETURN_TYPES = () + FUNCTION = "visualize_graph" + CATEGORY = "ImpactFrames💥🎞️" + OUTPUT_NODE = True + + def visualize_graph(self, graph_data, layout="spring"): + print(f"Visualizing graph: {graph_data}, layout: {layout}") + try: + if not os.path.exists(graph_data): + print(f"GraphML file not found: {graph_data}") + return {}, {"ui": {"error": f"GraphML file not found: {graph_data}"}} + + G = nx.read_graphml(graph_data) + graph_json = json.dumps(nx.node_link_data(G)) + print(f"Graph JSON (first 100 chars): {graph_json[:100]}...") + + return {}, {"ui": {"graph": graph_json, "layout": layout}} + + except Exception as e: + import traceback + error_message = f"Error: {str(e)}\nTraceback:\n{traceback.format_exc()}" + print(error_message) + return {}, {"ui": {"error": error_message}} + +NODE_CLASS_MAPPINGS = { + "IF_VisualizeGraph": IFVisualizeGraphNode +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "IF_VisualizeGraph": "IF Visualize Graph🕸️" } \ No newline at end of file diff --git a/README.md b/README.md index abaf61d..3202040 100644 --- a/README.md +++ b/README.md @@ -1,153 +1,151 @@ -[![ComfyUI-IF_AI_tools](https://img.youtube.com/vi/QAnapTWnawU/0.jpg)](https://youtu.be/QAnapTWnawU?si=Uomv_NXT2n2Mg9rG) - -# ComfyUI-IF_AI_tools - -ComfyUI-IF_AI_tools is a set of custom nodes to Run Local and API LLMs and LMMs, Features OCR-RAG (Bialdy), nanoGraphRAG, Supervision Object Detection, supports Ollama, LlamaCPP LMstudio, Koboldcpp, TextGen, Transformers or via APIs Anthropic, Groq, OpenAI, Google Gemini, Mistral, xAI and create your own charcters assistants (SystemPrompts) with custom presets and muchmore - - -# Prerequisite Installation (Poppler) - -To ensure compatibility and functionality with all tools, you may need `poppler` for PDF-related operations. Use `scoop` to install `poppler` on Windows: - -### Step 1: Install `scoop` (if not already installed) -If you haven't installed `scoop` yet, run the following command in **PowerShell**: - -```powershell -iwr -useb get.scoop.sh | iex -``` -Step 2: Install poppler with scoop -Once scoop is installed, you can install poppler by running: - -windows 10+ istall scoop and then -```powershell -scoop install poppler -``` -Debian/Ubuntu -```bash -sudo apt-get install poppler-utils -``` -MacOS -```bash -brew install poppler -``` - -check is working -```powershell -pdftotext -v -``` - -### Install Ollama - -You can technically use any LLM API that you want, but for the best expirience install Ollama and set it up. -- Visit [ollama.com](https://ollama.com) for more information. - -To install Ollama models just open CMD or any terminal and type the run command follow by the model name such as -```powershell -ollama run llama3.2-vision -``` -If you want to use omost -```bash -ollama run impactframes/dolphin_llama3_omost -``` -if you need a good smol model -```bash -ollama run ollama run llama3.2 -``` - -Optionally Set enviromnet variables for any of your favourite LLM API keys "XAI_API_KEY", "GOOGLE_API_KEY", "ANTHROPIC_API_KEY", "MISTRAL_API_KEY", "OPENAI_API_KEY" or "GROQ_API_KEY" with those names or otherwise -it won't pick it up you can also use .env file to store your keys - -## Features -_[NEW]_ nanoGraphRAG, -_[NEW]_ OCR-RAG ColPali & ColQwen (Bialdy), -_[NEW]_ Supervision Object Detection Florence2 -_[NEW]_ Endpoints xAI, Transformers, -_[NEW]_ IF_Assistants System Prompts with Reasoning/Reflection/Reward Templates and custom presets - -- Gemini, Groq, Mistral, OpenAI, Anthropic, Google, xAI, Transformers, Koboldcpp, TextGen, LlamaCPP, LMstudio, Ollama -- Omost_tool the first tool -- Vision Models Haiku, Florence2 -- [Ollama-Omost]https://ollama.com/impactframes/dolphin_llama3_omost can be 2x to 3x faster than other Omost Models -LLama3 and Phi3 IF_AI Prompt mkr models released -![ComfyUI_00021_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/fac9fb38-66ac-431b-8ef9-b0fee5d0e5dc) - -`ollama run impactframes/llama3_ifai_sd_prompt_mkr_q4km:latest` - -`ollama run impactframes/ifai_promptmkr_dolphin_phi3:latest` - -https://huggingface.co/impactframes/llama3_if_ai_sdpromptmkr_q4km - -https://huggingface.co/impactframes/ifai_promptmkr_dolphin_phi3_gguf - - -## Installation -1. Open the manager search for IF_AI_tools and install - -### Install ComfyUI-IF_AI_tools -hardest way - -1. Navigate to your ComfyUI `custom_nodes` folder, type `CMD` on the address bar to open a command prompt, - and run the following command to clone the repository: - ```bash - git clone https://github.com/if-ai/ComfyUI-IF_AI_tools.git - ``` -OR -1. In ComfyUI protable version just dounle click `embedded_install.bat` or type `CMD` on the address bar on the newly created `custom_nodes\ComfyUI-IF_AI_tools` folder type - ```bash - H:\ComfyUI_windows_portable\python_embeded\python.exe -m pip install -r requirements.txt - ``` - replace `C:\` for your Drive letter where you have the ComfyUI_windows_portable directory - -2. On custom environment activate the environment and move to the newly created ComfyUI-IF_AI_tools - ```bash - cd ComfyUI-IF_AI_tools - python -m pip install -r requirements.txt - ``` - -## Related Tools -- [IF_prompt_MKR](https://github.com/if-ai/IF_prompt_MKR) -- A similar tool available for Stable Diffusion WebUI - -## Videos -AIFuzz made a great video usining ollama and IF_AI tools - -[![AIFuzz](https://img.youtube.com/vi/nZx5g3TGsNc/0.jpg)](https://youtu.be/nZx5g3TGsNc?si=DFIqFuPoyKY1qJ2n) - -Also Future thinker @ Benji Thankyou both for putting out this awesome videos - -[![Future Thinker @Benji](https://img.youtube.com/vi/EQZWyn9eCFE/0.jpg)](https://youtu.be/EQZWyn9eCFE?si=jgC28GL7bwFWj_sK) - - -## Example using normal Model -ancient Megastructure, small lone figure -'A dwarfed figure standing atop an ancient megastructure, worn stone towering overhead. Underneath the dim moonlight, intricate engravings adorn the crumbling walls. Overwhelmed by the sheer size and age of the structure, the small figure appears lost amidst the weathered stone behemoth. The background reveals a dark landscape, dotted with faint twinkles from other ancient structures, scattered across the horizon. The silent air is only filled with the soft echoes of distant whispers, carrying secrets of times long past. ethereal-fantasy-concept-art, magical-ambiance, magnificent, celestial, ethereal-lighting, painterly, epic, majestic, dreamy-atmosphere, otherworldly, mystic-elements, surreal, immersive-detail' -![_IF_prompt_Mkr__00011_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/08dde522-f541-49f4-aa6b-e0653f13aa52) -![_IF_prompt_Mkr__00012_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/ec3ef715-fbe6-4ba0-80f8-00bf10f56f7b) -![_IF_prompt_Mkr__00010_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/e4dc671b-8eea-47f3-84ef-876e5938e120) -![_IF_prompt_Mkr__00014_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/d0b436cd-c4a8-41a2-83ad-34d8c50bb39b) - -## TODO -- Undergoing full refactor the next update will have three nodes one for general purpose chat, -one for Images creation, one complementary carrying the embeddings and auxiliary tools -- [ ] Fix Bugs and make it work on latest ComfyUI -- [ ] Fix Graph Visualizer Node -- [ ] Tweak IF_Assistants and Templates -- [ ] FrontEnd for IF_Assistants and Chat -- [ ] Node and workflow creator -- [ ] two additional Endpoints one API and one Local -- [ ] Add New workflows -- [ ] Image Generation, Text 2 Image, Image to Image, Video Generation - -## Support -If you find this tool useful, please consider supporting my work by: -- Starring the repository on GitHub: [ComfyUI-IF_AI_tools](https://github.com/if-ai/ComfyUI-IF_AI_tools) -- Subscribing to my YouTube channel: [Impact Frames](https://youtube.com/@impactframes?si=DrBu3tOAC2-YbEvc) -- Follow me on X: [Impact Frames X](https://x.com/impactframesX) -- Supporting me on Ko-fi: [Impact Frames Ko-fi](https://ko-fi.com/impactframes) -- Becoming a patron on Patreon: [Impact Frames Patreon](https://patreon.com/ImpactFrames) -Thank You! - -:IFAItools_comfy - - - - +[![ComfyUI-IF_AI_tools](https://img.youtube.com/vi/QAnapTWnawU/0.jpg)](https://youtu.be/QAnapTWnawU?si=Uomv_NXT2n2Mg9rG) + +# ComfyUI-IF_AI_tools + +ComfyUI-IF_AI_tools is a set of custom nodes to Run Local and API LLMs and LMMs, Features OCR-RAG (Bialdy), nanoGraphRAG, Supervision Object Detection, supports Ollama, LlamaCPP LMstudio, Koboldcpp, TextGen, Transformers or via APIs Anthropic, Groq, OpenAI, Google Gemini, Mistral, xAI and create your own charcters assistants (SystemPrompts) with custom presets and muchmore + + +# Prerequisite Installation (Poppler) + +To ensure compatibility and functionality with all tools, you may need `poppler` for PDF-related operations. Use `scoop` to install `poppler` on Windows: + +### Step 1: Install `scoop` (if not already installed) +If you haven't installed `scoop` yet, run the following command in **PowerShell**: + +```powershell +iwr -useb get.scoop.sh | iex +``` +Step 2: Install poppler with scoop +Once scoop is installed, you can install poppler by running: + +windows 10+ istall scoop and then +```powershell +scoop install poppler +``` +Debian/Ubuntu +```bash +sudo apt-get install poppler-utils +``` +MacOS +```bash +brew install poppler +``` + +check is working +```powershell +pdftotext -v +``` + +### Install Ollama + +You can technically use any LLM API that you want, but for the best expirience install Ollama and set it up. +- Visit [ollama.com](https://ollama.com) for more information. + +To install Ollama models just open CMD or any terminal and type the run command follow by the model name such as +```powershell +ollama run llama3.2-vision +``` +If you want to use omost +```bash +ollama run impactframes/dolphin_llama3_omost +``` +if you need a good smol model +```bash +ollama run ollama run llama3.2 +``` + +Optionally Set enviromnet variables for any of your favourite LLM API keys "XAI_API_KEY", "GOOGLE_API_KEY", "ANTHROPIC_API_KEY", "MISTRAL_API_KEY", "OPENAI_API_KEY" or "GROQ_API_KEY" with those names or otherwise +it won't pick it up you can also use .env file to store your keys + +## Features +_[NEW]_ nanoGraphRAG, +_[NEW]_ OCR-RAG ColPali & ColQwen (Bialdy), +_[NEW]_ Supervision Object Detection Florence2 +_[NEW]_ Endpoints xAI, Transformers, +_[NEW]_ IF_Assistants System Prompts with Reasoning/Reflection/Reward Templates and custom presets + +- Gemini, Groq, Mistral, OpenAI, Anthropic, Google, xAI, Transformers, Koboldcpp, TextGen, LlamaCPP, LMstudio, Ollama +- Omost_tool the first tool +- Vision Models Haiku, Florence2 +- [Ollama-Omost]https://ollama.com/impactframes/dolphin_llama3_omost can be 2x to 3x faster than other Omost Models +LLama3 and Phi3 IF_AI Prompt mkr models released +![ComfyUI_00021_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/fac9fb38-66ac-431b-8ef9-b0fee5d0e5dc) + +`ollama run impactframes/llama3_ifai_sd_prompt_mkr_q4km:latest` + +`ollama run impactframes/ifai_promptmkr_dolphin_phi3:latest` + +https://huggingface.co/impactframes/llama3_if_ai_sdpromptmkr_q4km + +https://huggingface.co/impactframes/ifai_promptmkr_dolphin_phi3_gguf + + +## Installation +1. Open the manager search for IF_AI_tools and install + +### Install ComfyUI-IF_AI_tools -hardest way + +1. Navigate to your ComfyUI `custom_nodes` folder, type `CMD` on the address bar to open a command prompt, + and run the following command to clone the repository: + ```bash + git clone https://github.com/if-ai/ComfyUI-IF_AI_tools.git + ``` +OR +1. In ComfyUI protable version just dounle click `embedded_install.bat` or type `CMD` on the address bar on the newly created `custom_nodes\ComfyUI-IF_AI_tools` folder type + ```bash + H:\ComfyUI_windows_portable\python_embeded\python.exe -m pip install -r requirements.txt + ``` + replace `C:\` for your Drive letter where you have the ComfyUI_windows_portable directory + +2. On custom environment activate the environment and move to the newly created ComfyUI-IF_AI_tools + ```bash + cd ComfyUI-IF_AI_tools + python -m pip install -r requirements.txt + ``` + +## Related Tools +- [IF_prompt_MKR](https://github.com/if-ai/IF_prompt_MKR) +- A similar tool available for Stable Diffusion WebUI + +## Videos +AIFuzz made a great video usining ollama and IF_AI tools + +[![AIFuzz](https://img.youtube.com/vi/nZx5g3TGsNc/0.jpg)](https://youtu.be/nZx5g3TGsNc?si=DFIqFuPoyKY1qJ2n) + +Also Future thinker @ Benji Thankyou both for putting out this awesome videos + +[![Future Thinker @Benji](https://img.youtube.com/vi/EQZWyn9eCFE/0.jpg)](https://youtu.be/EQZWyn9eCFE?si=jgC28GL7bwFWj_sK) + + +## Example using normal Model +ancient Megastructure, small lone figure +'A dwarfed figure standing atop an ancient megastructure, worn stone towering overhead. Underneath the dim moonlight, intricate engravings adorn the crumbling walls. Overwhelmed by the sheer size and age of the structure, the small figure appears lost amidst the weathered stone behemoth. The background reveals a dark landscape, dotted with faint twinkles from other ancient structures, scattered across the horizon. The silent air is only filled with the soft echoes of distant whispers, carrying secrets of times long past. ethereal-fantasy-concept-art, magical-ambiance, magnificent, celestial, ethereal-lighting, painterly, epic, majestic, dreamy-atmosphere, otherworldly, mystic-elements, surreal, immersive-detail' +![_IF_prompt_Mkr__00011_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/08dde522-f541-49f4-aa6b-e0653f13aa52) +![_IF_prompt_Mkr__00012_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/ec3ef715-fbe6-4ba0-80f8-00bf10f56f7b) +![_IF_prompt_Mkr__00010_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/e4dc671b-8eea-47f3-84ef-876e5938e120) +![_IF_prompt_Mkr__00014_](https://github.com/if-ai/ComfyUI-IF_AI_tools/assets/21185218/d0b436cd-c4a8-41a2-83ad-34d8c50bb39b) + +## TODO +- [ ] Fix Bugs and make it work on latest ComfyUI +- [ ] Fix Graph Visualizer Node +- [ ] Tweak IF_Assistants and Templates +- [ ] FrontEnd for IF_Assistants and Chat +- [ ] Node and workflow creator +- [ ] two additional Endpoints one API and one Local +- [ ] Add New workflows +- [ ] Image Generation, Text 2 Image, Image to Image, Video Generation + +## Support +If you find this tool useful, please consider supporting my work by: +- Starring the repository on GitHub: [ComfyUI-IF_AI_tools](https://github.com/if-ai/ComfyUI-IF_AI_tools) +- Subscribing to my YouTube channel: [Impact Frames](https://youtube.com/@impactframes?si=DrBu3tOAC2-YbEvc) +- Follow me on X: [Impact Frames X](https://x.com/impactframesX) +- Supporting me on Ko-fi: [Impact Frames Ko-fi](https://ko-fi.com/impactframes) +- Becoming a patron on Patreon: [Impact Frames Patreon](https://patreon.com/ImpactFrames) +Thank You! + +:IFAItools_comfy + + + + diff --git a/__init__.py b/__init__.py index cad9756..fca3c32 100644 --- a/__init__.py +++ b/__init__.py @@ -1,90 +1,94 @@ -import os -import importlib.util -import glob -import shutil -import sys -import folder_paths -from aiohttp import web - -from .IFPromptMkrNode import IFPrompt2Prompt -from .IFImagePromptNode import IFImagePrompt -from .IFSaveTextNode import IFSaveText -from .IFDisplayTextNode import IFDisplayText -from .IFChatPromptNode import IFChatPrompt -from .IFDisplayOmniNode import IFDisplayOmni -from .IFTextTyperNode import IFTextTyper -from .IFVisualizeGraphNode import IFVisualizeGraphNode -from .IFStepCounterNode import IFCounter -from .IFJoinTextNode import IFJoinText -from .IFLoadImagesNode import IFLoadImages -from .send_request import * - -# Try to import omost from the current directory -# Add the current directory to sys.path -current_dir = os.path.dirname(os.path.abspath(__file__)) -if current_dir not in sys.path: - sys.path.insert(0, current_dir) -#print(f"Current directory: {current_dir}") -#print(f"Files in current directory: {os.listdir(current_dir)}") -try: - from .omost import omost_function - print("Successfully imported omost_function from omost.py in the current directory") -except ImportError as e: - print(f"Error importing omost from current directory: {e}") - - # If import fails, try to import from the parent directory - parent_dir = os.path.dirname(current_dir) - parent_dir_name = os.path.basename(parent_dir) - if parent_dir_name == 'ComfyUI_IF_AI_tools': - sys.path.insert(0, parent_dir) - try: - from omost import omost_function - print(f"Successfully imported omost_function from {parent_dir}/omost.py") - except ImportError as e: - print(f"Error importing omost from parent directory: {e}") - print(f"Current sys.path: {sys.path}") - raise -class OmniType(str): - """A special string type that acts as a wildcard for universal input/output. - It always evaluates as equal in comparisons.""" - def __ne__(self, __value: object) -> bool: - return False - -OMNI = OmniType("*") - - -NODE_CLASS_MAPPINGS = { - "IF_PromptMkr": IFPrompt2Prompt, - "IF_ImagePrompt": IFImagePrompt, - "IF_SaveText": IFSaveText, - "IF_DisplayText": IFDisplayText, - "IF_ChatPrompt": IFChatPrompt, - "IF_DisplayOmni": IFDisplayOmni, - "IF_TextTyper": IFTextTyper, - "IF_VisualizeGraph": IFVisualizeGraphNode, - "IF_StepCounter": IFCounter, - "IF_JoinText": IFJoinText, - "IF_LoadImages": IFLoadImages, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "IF_PromptMkr": "IF Prompt to Prompt💬", - "IF_ImagePrompt": "IF Image to Prompt🖼️", - "IF_SaveText": "IF Save Text📝", - "IF_DisplayText": "IF Display Text📟", - "IF_ChatPrompt": "IF Chat Prompt👨‍💻", - "IF_DisplayOmni": "IF Display Omni🔍", - "IF_TextTyper": "IF Text Typer✍️", - "IF_VisualizeGraph": "IF Visualize Graph🕸️", - "IF_StepCounter": "IF Step Counter 🔢", - "IF_JoinText": "IF Join Text 📝", - "IF_LoadImages": "IF Load Images🖼️", -} - -WEB_DIRECTORY = "./web" -__all__ = [ - "NODE_CLASS_MAPPINGS", - "NODE_DISPLAY_NAME_MAPPINGS", - "WEB_DIRECTORY", - "omost_function" - ] +import os +import importlib.util +import glob +import shutil +import sys +import folder_paths +from aiohttp import web + +from .IFChatPromptNode import IFChatPrompt +from .IFImagePromptNode import IFImagePrompt +from .IFPromptMkrNode import IFPrompt2Prompt +from .IFSupervisionNode import IFSupervision +from .IFDisplayTextWildcardNode import IFDisplayTextWildcard +from .IFSaveTextNode import IFSaveText +from .IFDisplayTextNode import IFDisplayText +from .IFDisplayOmniNode import IFDisplayOmni +from .IFTextTyperNode import IFTextTyper +from .IFVisualizeGraphNode import IFVisualizeGraphNode +from .IFStepCounterNode import IFCounter +from .IFJoinTextNode import IFJoinText +from .IFLoadImagesNode import IFLoadImagess +from .send_request import * + +# Try to import omost from the current directory +# Add the current directory to sys.path +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.insert(0, current_dir) +#print(f"Current directory: {current_dir}") +#print(f"Files in current directory: {os.listdir(current_dir)}") +try: + from .omost import omost_function + print("Successfully imported omost_function from omost.py in the current directory") +except ImportError as e: + print(f"Error importing omost from current directory: {e}") + + # If import fails, try to import from the parent directory + parent_dir = os.path.dirname(current_dir) + parent_dir_name = os.path.basename(parent_dir) + if parent_dir_name == 'ComfyUI-IF_AI_tools': + sys.path.insert(0, parent_dir) + try: + from omost import omost_function + print(f"Successfully imported omost_function from {parent_dir}/omost.py") + except ImportError as e: + print(f"Error importing omost from parent directory: {e}") + print(f"Current sys.path: {sys.path}") + raise +class OmniType(str): + """A special string type that acts as a wildcard for universal input/output. + It always evaluates as equal in comparisons.""" + def __ne__(self, __value: object) -> bool: + return False + +OMNI = OmniType("*") + +NODE_CLASS_MAPPINGS = { + "IF_ChatPrompt": IFChatPrompt, + "IF_PromptMkr": IFPrompt2Prompt, + "IF_ImagePrompt": IFImagePrompt, + "IF_Supervision": IFSupervision, + "IF_SaveText": IFSaveText, + "IF_DisplayText": IFDisplayText, + "IF_DisplayTextWildcard": IFDisplayTextWildcard, + "IF_DisplayOmni": IFDisplayOmni, + "IF_TextTyper": IFTextTyper, + "IF_VisualizeGraph": IFVisualizeGraphNode, + "IF_StepCounter": IFCounter, + "IF_JoinText": IFJoinText, + "IF_LoadImagesS": IFLoadImagess, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "IF_ChatPrompt": "IF Chat Prompt👨‍💻", + "IF_PromptMkr": "IF Prompt Maker🎨", + "IF_ImagePrompt": "IF Image to Prompt🖼️", + "IF_SaveText": "IF Save Text📝", + "IF_DisplayText": "IF Display Text📟", + "IF_DisplayTextWildcard": "IF Display Text Wildcard📟", + "IF_DisplayOmni": "IF Display Omni🔍", + "IF_TextTyper": "IF Text Typer✍️", + "IF_VisualizeGraph": "IF Visualize Graph🕸️", + "IF_StepCounter": "IF Step Counter 🔢", + "IF_JoinText": "IF Join Text 📝", + "IF_LoadImagesS": "IF Load Images S 🖼️" +} + +WEB_DIRECTORY = "./web" +__all__ = [ + "NODE_CLASS_MAPPINGS", + "NODE_DISPLAY_NAME_MAPPINGS", + "WEB_DIRECTORY", + "omost_function" + ] diff --git a/agent_tool.py b/agent_tool.py index 5fe0556..e9a4b82 100644 --- a/agent_tool.py +++ b/agent_tool.py @@ -23,7 +23,7 @@ def __init__(self, name, description, system_prompt, default_engine, default_mod def load(self): # Construct the path to the ComfyUI-IF_AI_tools directory - if_ai_tools_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools") + if_ai_tools_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools") # Add the ComfyUI-IF_AI_tools directory to sys.path if if_ai_tools_dir not in sys.path: diff --git a/colpaliRAG_module.py b/colpaliRAG_module.py index bb7de1c..1578ccb 100644 --- a/colpaliRAG_module.py +++ b/colpaliRAG_module.py @@ -1,666 +1,666 @@ -import os -import logging -import torch -import builtins -from byaldi import RAGMultiModalModel -from .graphRAG_module import GraphRAGapp -from typing import Tuple, Optional, Dict, Union, List, Any -from pathlib import Path -import numpy as np -from PIL import Image -from io import BytesIO -import base64 -from .send_request import send_request -import asyncio -import json -import shutil -from PIL import Image -from io import BytesIO -from pdf2image import convert_from_path -from .utils import get_api_key, load_placeholder_image - -import comfy.model_management as mm -from comfy.utils import ProgressBar -import folder_paths -from .transformers_api import TransformersModelManager - -import sys - -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Global variable for model caching -if not hasattr(builtins, 'global_colpali_model'): - builtins.global_colpali_model = None - -class colpaliRAGapp: - def __init__(self): - self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "rag") - self._rag_root_dir = None - self._input_dir = None - self.graphrag_app = GraphRAGapp() - self.transformers_api = TransformersModelManager() - self.cached_index_model = None # Add this line to store the loaded index - self.cached_index_name = None # Add this line to track which index is loaded - - @property - def rag_root_dir(self): - return self._rag_root_dir - - @rag_root_dir.setter - def rag_root_dir(self, value): - self._rag_root_dir = value - self._input_dir = os.path.join(self.rag_dir, value, "input") if value else None - logger.debug(f"rag_root_dir setter: set to {self._rag_root_dir}") - logger.debug(f"input_dir set to {self._input_dir}") - if self._input_dir: - os.makedirs(self._input_dir, exist_ok=True) - logger.debug(f"Created input directory: {self._input_dir}") - - def set_rag_root_dir(self, rag_folder_name): - if rag_folder_name: - new_rag_root_dir = os.path.join(self.rag_dir, rag_folder_name) - else: - new_rag_root_dir = os.path.join(self.rag_dir, "rag_data") - - self._rag_root_dir = rag_folder_name # Ensure rag_folder_name is a string - self._input_dir = os.path.join(new_rag_root_dir, "input") - - # Ensure directories exist - os.makedirs(self._rag_root_dir, exist_ok=True) - os.makedirs(self._input_dir, exist_ok=True) - - logger.debug(f"set_rag_root_dir: rag_root_dir set to {self._rag_root_dir}") - logger.debug(f"set_rag_root_dir: input_dir set to {self._input_dir}") - return self._rag_root_dir, self._input_dir - - @classmethod - def get_colpali_model(cls, query_type): - logger.debug(f"Attempting to get ColPali model. Current state: {'Loaded' if builtins.global_colpali_model is not None else 'Not loaded'}") - - if query_type == "colqwen2": - model_path = os.path.join(folder_paths.models_dir, "LLM", "colqwen2-v0.1") - elif query_type == "colpali-v1.2": - model_path = os.path.join(folder_paths.models_dir, "LLM", "colpali-v1.2") - elif query_type == "colpali": - model_path = os.path.join(folder_paths.models_dir, "LLM", "colpali") - else: - logger.error(f"Invalid query type: {query_type}") - return None - - if builtins.global_colpali_model is None: - torch.cuda.empty_cache() - logger.info(f"Loading ColPali model from {model_path}") - try: - builtins.global_colpali_model = RAGMultiModalModel.from_pretrained( - model_path, - device="cuda", - verbose=1, - ) - logger.info("ColPali model loaded and cached globally on CUDA") - except Exception as e: - logger.error(f"Error loading ColPali model: {str(e)}") - builtins.global_colpali_model = None - else: - logger.info("Using existing globally cached ColPali model") - return builtins.global_colpali_model - - async def insert(self): - try: - logger.debug("Starting insert function...") - settings = self.graphrag_app.load_settings() - - index_path = os.path.join(self.rag_dir, settings.get("rag_folder_name"), "input") - index_name = settings.get("rag_folder_name") - query_type = settings.get("query_type", "colqwen") - - # Get the model - colpali_model = self.get_colpali_model(query_type) - if colpali_model is None: - logger.error("Failed to load ColPali model for indexing") - return False - - # Check if index already exists - index_folder = os.path.join(".byaldi", index_name) - is_existing_index = os.path.exists(index_folder) - - # Patch the processor's process_images method - original_process_images = colpali_model.model.processor.process_images - - def process_images_wrapper(*args, **kwargs): - result = original_process_images(*args, **kwargs) - processed = {} - for k, v in result.items(): - if torch.is_tensor(v): - if 'pixel_values' in k: - processed[k] = v.to(dtype=torch.bfloat16) - else: - processed[k] = v.to(dtype=torch.long) - else: - processed[k] = v - return processed - - try: - # Apply the patch - colpali_model.model.processor.process_images = process_images_wrapper - - # Create new index or add to existing one - if is_existing_index: - logger.info(f"Found existing index: {index_name}, loading it first...") - # Load existing index - colpali_model = self.get_colpali_model(query_type) - colpali_model = colpali_model.from_index( - index_name, - index_root=".byaldi", - device="cuda", - verbose=1 - ) - - # Get list of already indexed files - indexed_files = set() - if hasattr(colpali_model.model, 'doc_ids_to_file_names'): - indexed_files = set(colpali_model.model.doc_ids_to_file_names.values()) - - # Process new files only - new_files = [] - for file in os.listdir(index_path): - file_path = os.path.join(index_path, file) - if file_path not in indexed_files: - new_files.append(file_path) - - if new_files: - logger.info(f"Adding {len(new_files)} new documents to existing index") - colpali_model.index( - input_path=new_files, - index_name=index_name, - store_collection_with_index=False, - overwrite=False, - max_image_width=1024, - max_image_height=1024, - ) - else: - logger.info("No new documents to add to the index") - - else: - logger.info(f"Creating new index: {index_name}") - colpali_model.index( - input_path=index_path, - index_name=index_name, - store_collection_with_index=False, - overwrite=False, - max_image_width=1024, - max_image_height=1024, - ) - - # Remove extra folder if it exists - extra_folder = os.path.join(self.comfy_dir, settings.get("rag_folder_name")) - if os.path.exists(extra_folder) and os.path.isdir(extra_folder): - logger.debug(f"Removing extra folder: {extra_folder}") - shutil.rmtree(extra_folder) - - return True - - finally: - # Restore original process_images method - colpali_model.model.processor.process_images = original_process_images - - except Exception as e: - logger.error(f"Error during indexing: {str(e)}") - return False - - def load_indexed_images(self, index_name, size): - """Load previously indexed images from the stored index and convert new PDFs""" - try: - index_path = os.path.join(self.rag_dir, index_name, "input") - images_path = os.path.join(self.rag_dir, index_name, "converted_images") - os.makedirs(images_path, exist_ok=True) - - # Get sorted list of PDF files to ensure consistent ordering - pdf_files = sorted([f for f in os.listdir(index_path) if f.lower().endswith('.pdf')]) - - if not pdf_files: - logger.warning("No PDF files found in the input directory.") - return None - - all_images = {} - - # Process each PDF file - for doc_id, pdf_file in enumerate(pdf_files): - pdf_name = os.path.splitext(pdf_file)[0] - pdf_images_dir = os.path.join(images_path, pdf_name) - pdf_path = os.path.join(index_path, pdf_file) - - logger.debug(f"Processing PDF {pdf_file} as doc_id {doc_id}") - - # Check if conversion is needed - needs_conversion = True - if os.path.exists(pdf_images_dir) and os.listdir(pdf_images_dir): - pdf_mtime = os.path.getmtime(pdf_path) - newest_image = max( - os.path.getmtime(os.path.join(pdf_images_dir, f)) - for f in os.listdir(pdf_images_dir) - if f.endswith('.png') - ) - needs_conversion = pdf_mtime > newest_image - - if needs_conversion: - logger.info(f"Converting PDF: {pdf_file}") - os.makedirs(pdf_images_dir, exist_ok=True) - images = convert_from_path( - pdf_path, - thread_count=os.cpu_count() - 1, - fmt='png', - paths_only=False, - size=size, - ) - - # Save converted images with consistent naming - for page_num, img in enumerate(images, 1): - img_path = os.path.join(pdf_images_dir, f"page_{page_num:03d}.png") - img.save(img_path, "PNG") - logger.debug(f"Saved {img_path}") - - # Load images in correct order - page_files = sorted( - [f for f in os.listdir(pdf_images_dir) if f.endswith('.png')], - key=lambda x: int(x.split('_')[1].split('.')[0]) # Sort by page number - ) - - images = [] - for page_file in page_files: - img_path = os.path.join(pdf_images_dir, page_file) - try: - img = Image.open(img_path) - if img.mode != 'RGB': - img = img.convert('RGB') - images.append(img) - logger.debug(f"Loaded {img_path}") - except Exception as e: - logger.error(f"Error loading {img_path}: {e}") - continue - - all_images[doc_id] = images - logger.debug(f"Loaded {len(images)} pages for document {doc_id} ({pdf_file})") - - return all_images - - except Exception as e: - logger.error(f"Error loading indexed images: {str(e)}") - return None - - def get_top_results(self, results, all_images, index_name, llm_provider="transformers"): - """ - Get relevant images and their metadata based on search results. - Returns lists of images and masks along with result information. - - Args: - results: Search results to process - all_images: Dictionary of loaded images - index_name: Name of the index being used - llm_provider: The LLM provider being used (default: "transformers") - """ - top_results_images = [] - top_results_masks = [] - result_info = [] - - try: - # Get list of PDF files - index_path = os.path.join(self.rag_dir, index_name, "input") - pdf_files = sorted([f for f in os.listdir(index_path) if f.lower().endswith('.pdf')]) - - # Sort and filter results with safe score extraction - def get_score(result): - try: - return float(result.score) if hasattr(result, 'score') else 0.0 - except (ValueError, TypeError): - return 0.0 - - sorted_results = sorted(results, key=get_score, reverse=True) - - # For non-standard LLM providers, only take the highest scoring result - if llm_provider.lower() not in ["transformers", "openai", "anthropic"]: - sorted_results = sorted_results[:1] - - logger.debug("Processing sorted results:") - for r in sorted_results: - score = get_score(r) - try: - doc_id = int(r.doc_id) if isinstance(r.doc_id, str) else r.doc_id - page_num = int(r.page_num) if isinstance(r.page_num, str) else r.page_num - logger.debug(f"Doc: {doc_id}, Page: {page_num}, Score: {score}") - except (ValueError, TypeError) as e: - logger.warning(f"Invalid result format: {e}") - continue - - # Create a mapping of original doc_ids to ensure correct ordering - doc_id_map = {i: os.path.splitext(pdf_file)[0] for i, pdf_file in enumerate(pdf_files)} - - for result in sorted_results: - # Updated integer conversion logic - try: - doc_id = int(result.doc_id) if isinstance(result.doc_id, str) else result.doc_id - page_num = int(result.page_num) if isinstance(result.page_num, str) else result.page_num - except (ValueError, TypeError) as e: - logger.warning(f"Invalid doc_id or page_num format: {e}") - continue - - # Validate document ID - if doc_id not in all_images: - logger.warning(f"Document ID {doc_id} not found in loaded images") - continue - - # Validate page number (convert to 0-based index) - page_idx = page_num - 1 - if page_idx < 0 or page_idx >= len(all_images[doc_id]): - logger.warning(f"Invalid page {page_num} for document {doc_id}") - continue - - # Get image for this result - try: - image = all_images[doc_id][page_idx] - logger.debug(f"Retrieved image for doc {doc_id} ('{doc_id_map[doc_id]}'), page {page_num}") - - # Ensure image is in RGB format - if image.mode != 'RGB': - image = image.convert('RGB') - - # Convert to tensor - img_array = np.array(image).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - - # Create corresponding mask - mask_tensor = torch.ones((1, img_array.shape[0], img_array.shape[1]), - dtype=torch.float32, device="cpu") - - # Add to results - top_results_images.append(img_tensor) - top_results_masks.append(mask_tensor) - - # Store result info - result_info.append({ - "doc_id": doc_id, - "page_num": page_num, - "score": float(result.score) if hasattr(result, 'score') else 0.0, - "pdf_name": doc_id_map[doc_id], - "metadata": result.metadata if hasattr(result, 'metadata') else {} - }) - - except Exception as e: - logger.error(f"Error processing image for doc {doc_id}, page {page_num}: {e}") - continue - - if top_results_images: - # Combine tensors - try: - top_combined_images = torch.cat(top_results_images, dim=0) - top_combined_masks = torch.cat(top_results_masks, dim=0) - - # Add debug logging for tensor shapes - logger.debug(f"Number of images processed: {len(top_results_images)}") - logger.debug(f"Combined images tensor shape: {top_combined_images.shape}") - logger.debug(f"Combined masks tensor shape: {top_combined_masks.shape}") - - if llm_provider.lower() not in ["transformers", "openai", "anthropic"]: - # Verify we only have one image - assert top_combined_images.shape[0] == 1, "Expected only one image for non-standard LLM provider" - assert top_combined_masks.shape[0] == 1, "Expected only one mask for non-standard LLM provider" - logger.debug("Confirmed single image output for non-standard LLM provider") - - logger.debug("Final result order:") - for info in result_info: - logger.debug(f"Doc: {info['doc_id']} ({info['pdf_name']}), Page: {info['page_num']}, Score: {info['score']:.2f}") - - return top_combined_images, top_combined_masks, result_info - - except Exception as e: - logger.error(f"Error combining tensors: {e}") - return None, None, [] - - logger.debug("No valid images to process") - return None, None, [] - - except Exception as e: - logger.error(f"Error in get_top_results: {str(e)}") - return None, None, [] - - async def query(self, prompt: str, query_type: str, system_message_str: str, **kwargs): - try: - # 1. Initialize settings and parameters - settings: dict = self.graphrag_app.load_settings() - llm_provider: str = kwargs.pop("llm_provider", settings.get("llm_provider", "ollama")) - base_ip: str = kwargs.pop("base_ip", settings.get("base_ip", "localhost")) - port: str = kwargs.pop("port", settings.get("port", "11434")) - llm_model: str = kwargs.pop('llm_model', settings.get("llm_model", "llama3.1:latest")) - llm_api_key: str = settings.get('external_llm_api_key') if settings.get('external_llm_api_key') != "" else get_api_key(f"{settings['llm_provider'].upper()}_API_KEY", settings['llm_provider']) - keep_alive: str = kwargs.pop("keep_alive", settings.get("keep_alive", "False")) - seed: str = kwargs.pop("seed", settings.get("seed", "None")) - temperature: float = float(kwargs.pop("temperature", settings.get("temperature", "0.7"))) - top_p: float = float(kwargs.pop("top_p", settings.get("top_p", "0.90"))) - top_k: int = int(kwargs.pop("top_k", settings.get("top_k", "40"))) - max_tokens: int = int(kwargs.pop("max_tokens", settings.get("max_tokens", "2048"))) - presence_penalty: float = float(kwargs.pop("repeat_penalty", settings.get("repeat_penalty", "1.2"))) - random: str = kwargs.pop("random", settings.get("random", "False")) - stop: str = kwargs.pop("stop", settings.get("stop", "None")) - precision: str = kwargs.pop("precision", settings.get("precision", "fp16")) - attention: str = kwargs.pop("attention", settings.get("attention", "sdpa")) - index_name: str = kwargs.pop("rag_folder_name", settings.get("rag_folder_name")) - prime_directives: str = kwargs.pop("prime_directives", settings.get("prime_directives", "None")) - aspect_ratio: str = kwargs.pop("aspect_ratio", settings.get("aspect_ratio", "16:9")) - top_k_search: str = kwargs.pop("top_k_search", settings.get("top_k_search", "3")) - #vertical/horizontal is x/y on pdf2image that is why I inverted the aspect ratio - size: tuple[int, int] = (768, 1024) if aspect_ratio == "16:9" else (1024, 768) if aspect_ratio == "9:16" else (1024, 1024) - messages: list = [] - if prime_directives != "None": - system_message_str = prime_directives - elif system_message_str == "None": - system_message_str = "You are a helpful assistant. Analyze the image and answer the user's question." - - # 2. Get and validate model - colpali_model = await self._prepare_model(query_type, index_name) - if not colpali_model: - return self._create_error_response(prompt, "ColPali model not available") - - # 3. Perform search and validate results - top_k_search_int: int = int(top_k_search) - results: Union[List[str], List[Any]] = await self._perform_search( - colpali_model, - str(prompt), - top_k_search_int, - llm_provider=llm_provider - ) - if not results: - return self._create_error_response(prompt, "No relevant documents found", tool_output="Search returned no results") - - # 4. Process images - image_data: Optional[Union[ - Tuple[torch.Tensor, torch.Tensor, List[Dict]], - Tuple[torch.Tensor, torch.Tensor, List[str]] - ]] = await self._process_images( - results, - index_name, - size, - llm_provider=llm_provider - ) - if not image_data: - return self._create_error_response(prompt, "Failed to process images") - - images_tensor: torch.Tensor - masks_tensor: torch.Tensor - result_info: Union[List[Dict], List[str]] - images_tensor, masks_tensor, result_info = image_data - - logger.debug(f"Images tensor shape: {images_tensor.shape}") - logger.debug(f"Result info: {result_info}") - - try: - generated_text: Union[str, List[str]] - generated_text = await send_request( - llm_provider=llm_provider, - base_ip=base_ip, - port=port, - images=images_tensor, - llm_model=llm_model, - system_message=system_message_str, - user_message=prompt, - messages=messages, - seed=seed, - temperature=temperature, - max_tokens=max_tokens, - random=random, - top_k=top_k, - top_p=top_p, - repeat_penalty=presence_penalty, - stop=stop, - keep_alive=keep_alive, - llm_api_key=llm_api_key if llm_api_key != "" else None, - precision=precision, - attention=attention, - ) - - # Handle case where generated_text is a list - if isinstance(generated_text, list): - generated_text = "\n".join(generated_text) # Join with newlines to preserve formatting - - except Exception as e: - logger.error(f"Error in API request: {str(e)}") - return { - "Question": prompt, - "Response": f"Error communicating with {llm_provider}: {str(e)}", - "Negative": "", - "Tool_Output": str(result_info), - "Retrieved_Image": images_tensor.detach() if torch.is_tensor(images_tensor) else None, - "Mask": masks_tensor.detach() if torch.is_tensor(masks_tensor) else None - } - - # 6. Format and return response - return { - "Question": prompt, - "Response": generated_text, - "Negative": "", - "Tool_Output": self._format_tool_output(result_info), - "Retrieved_Image": images_tensor, - "Mask": masks_tensor - } - - except Exception as e: - logger.error(f"Error in colpali query: {str(e)}") - return self._create_error_response(prompt, f"Error processing query: {str(e)}") - - async def _prepare_model(self, query_type, index_name): - """Prepare and validate the ColPali model""" - try: - # Check if we already have the correct index loaded - if self.cached_index_model is not None and self.cached_index_name == index_name: - logger.debug(f"Using cached index model for {index_name}") - return self.cached_index_model - - # Get base model - colpali_model = self.get_colpali_model(query_type) - if not colpali_model: - logger.error("Failed to get base ColPali model") - return None - - # Load new index - logger.debug(f"Loading new index {index_name} from .byaldi...") - try: - model = RAGMultiModalModel.from_index( - index_name, - index_root=".byaldi", - device="cuda", - verbose=1 - ) - - # Verify index loaded correctly - if not hasattr(model.model, 'indexed_embeddings') or not model.model.indexed_embeddings: - logger.error("Index loaded but no embeddings found") - return None - - # Cache the successfully loaded index - self.cached_index_model = model - self.cached_index_name = index_name - - logger.debug(f"Successfully loaded and cached index with {len(model.model.indexed_embeddings)} embeddings") - return model - - except Exception as e: - logger.error(f"Error loading index: {str(e)}") - return None - - except Exception as e: - logger.error(f"Error in _prepare_model: {str(e)}") - return None - - async def _perform_search(self, model, prompt, top_k_search, llm_provider="transformers"): - """ - Perform search and validate results - - Args: - model: The RAG model to use for search - prompt: The search prompt - top_k_search: Number of results to return - llm_provider: The LLM provider being used (default: "transformers") - """ - logger.debug(f"Searching with prompt: {prompt}") - - # Adjust top_k based on llm_provider - if llm_provider.lower() not in ["transformers", "openai", "anthropic"]: - top_k_search = 1 - - results = model.search(prompt, k=top_k_search) - - if results: - for result in results: - if hasattr(result, 'doc_id'): - result.doc_id = int(result.doc_id) if isinstance(result.doc_id, str) else result.doc_id - if hasattr(result, 'page_num'): - result.page_num = int(result.page_num) if isinstance(result.page_num, str) else result.page_num - return results - - async def _process_images(self, results, index_name, size, llm_provider="transformers"): - """Process and validate images from search results""" - all_images = self.load_indexed_images(index_name, size) - if all_images is None: - return None - - return self.get_top_results(results, all_images, index_name, llm_provider) - - def _format_tool_output(self, result_info): - """Format tool output text""" - tool_text = "Retrieved Documents:\n" - for info in result_info: - tool_text += f"\nDocument {info['doc_id']}, Page {info['page_num']}" - tool_text += f"\nRelevance Score: {info['score']:.2f}" - if info['metadata']: - tool_text += f"\nMetadata: {info['metadata']}" - tool_text += "\n" - return tool_text - - def _create_error_response(self, prompt, error_message, tool_output=None): - """Create standardized error response""" - return { - "Question": prompt, - "Response": error_message, - "Negative": "", - "Tool_Output": tool_output, - "Retrieved_Image": None, - "Mask": None - } - - def cleanup(self): - """Cleanup method to free up GPU memory when needed""" - if builtins.global_colpali_model: - del builtins.global_colpali_model - builtins.global_colpali_model = None - torch.cuda.empty_cache() - logger.info("Cleaned up models and freed GPU memory") - - def cleanup_index(self): - """Method to manually clear the cached index if needed""" - self.cached_index_model = None - self.cached_index_name = None - torch.cuda.empty_cache() - logger.info("Cleared cached index and freed GPU memory") - - async def _generate_text_response(self, images_tensor, prompt, system_message_str, params): - """Generate text response based on provider""" +import os +import logging +import torch +import builtins +from byaldi import RAGMultiModalModel +from .graphRAG_module import GraphRAGapp +from typing import Tuple, Optional, Dict, Union, List, Any +from pathlib import Path +import numpy as np +from PIL import Image +from io import BytesIO +import base64 +from .send_request import send_request +import asyncio +import json +import shutil +from PIL import Image +from io import BytesIO +from pdf2image import convert_from_path +from .utils import get_api_key, load_placeholder_image + +import comfy.model_management as mm +from comfy.utils import ProgressBar +import folder_paths +from .transformers_api import TransformersModelManager + +import sys + +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Global variable for model caching +if not hasattr(builtins, 'global_colpali_model'): + builtins.global_colpali_model = None + +class colpaliRAGapp: + def __init__(self): + self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "rag") + self._rag_root_dir = None + self._input_dir = None + self.graphrag_app = GraphRAGapp() + self.transformers_api = TransformersModelManager() + self.cached_index_model = None # Add this line to store the loaded index + self.cached_index_name = None # Add this line to track which index is loaded + + @property + def rag_root_dir(self): + return self._rag_root_dir + + @rag_root_dir.setter + def rag_root_dir(self, value): + self._rag_root_dir = value + self._input_dir = os.path.join(self.rag_dir, value, "input") if value else None + logger.debug(f"rag_root_dir setter: set to {self._rag_root_dir}") + logger.debug(f"input_dir set to {self._input_dir}") + if self._input_dir: + os.makedirs(self._input_dir, exist_ok=True) + logger.debug(f"Created input directory: {self._input_dir}") + + def set_rag_root_dir(self, rag_folder_name): + if rag_folder_name: + new_rag_root_dir = os.path.join(self.rag_dir, rag_folder_name) + else: + new_rag_root_dir = os.path.join(self.rag_dir, "rag_data") + + self._rag_root_dir = rag_folder_name # Ensure rag_folder_name is a string + self._input_dir = os.path.join(new_rag_root_dir, "input") + + # Ensure directories exist + os.makedirs(self._rag_root_dir, exist_ok=True) + os.makedirs(self._input_dir, exist_ok=True) + + logger.debug(f"set_rag_root_dir: rag_root_dir set to {self._rag_root_dir}") + logger.debug(f"set_rag_root_dir: input_dir set to {self._input_dir}") + return self._rag_root_dir, self._input_dir + + @classmethod + def get_colpali_model(cls, query_type): + logger.debug(f"Attempting to get ColPali model. Current state: {'Loaded' if builtins.global_colpali_model is not None else 'Not loaded'}") + + if query_type == "colqwen2": + model_path = os.path.join(folder_paths.models_dir, "LLM", "colqwen2-v0.1") + elif query_type == "colpali-v1.2": + model_path = os.path.join(folder_paths.models_dir, "LLM", "colpali-v1.2") + elif query_type == "colpali": + model_path = os.path.join(folder_paths.models_dir, "LLM", "colpali") + else: + logger.error(f"Invalid query type: {query_type}") + return None + + if builtins.global_colpali_model is None: + torch.cuda.empty_cache() + logger.info(f"Loading ColPali model from {model_path}") + try: + builtins.global_colpali_model = RAGMultiModalModel.from_pretrained( + model_path, + device="cuda", + verbose=1, + ) + logger.info("ColPali model loaded and cached globally on CUDA") + except Exception as e: + logger.error(f"Error loading ColPali model: {str(e)}") + builtins.global_colpali_model = None + else: + logger.info("Using existing globally cached ColPali model") + return builtins.global_colpali_model + + async def insert(self): + try: + logger.debug("Starting insert function...") + settings = self.graphrag_app.load_settings() + + index_path = os.path.join(self.rag_dir, settings.get("rag_folder_name"), "input") + index_name = settings.get("rag_folder_name") + query_type = settings.get("query_type", "colqwen") + + # Get the model + colpali_model = self.get_colpali_model(query_type) + if colpali_model is None: + logger.error("Failed to load ColPali model for indexing") + return False + + # Check if index already exists + index_folder = os.path.join(".byaldi", index_name) + is_existing_index = os.path.exists(index_folder) + + # Patch the processor's process_images method + original_process_images = colpali_model.model.processor.process_images + + def process_images_wrapper(*args, **kwargs): + result = original_process_images(*args, **kwargs) + processed = {} + for k, v in result.items(): + if torch.is_tensor(v): + if 'pixel_values' in k: + processed[k] = v.to(dtype=torch.bfloat16) + else: + processed[k] = v.to(dtype=torch.long) + else: + processed[k] = v + return processed + + try: + # Apply the patch + colpali_model.model.processor.process_images = process_images_wrapper + + # Create new index or add to existing one + if is_existing_index: + logger.info(f"Found existing index: {index_name}, loading it first...") + # Load existing index + colpali_model = self.get_colpali_model(query_type) + colpali_model = colpali_model.from_index( + index_name, + index_root=".byaldi", + device="cuda", + verbose=1 + ) + + # Get list of already indexed files + indexed_files = set() + if hasattr(colpali_model.model, 'doc_ids_to_file_names'): + indexed_files = set(colpali_model.model.doc_ids_to_file_names.values()) + + # Process new files only + new_files = [] + for file in os.listdir(index_path): + file_path = os.path.join(index_path, file) + if file_path not in indexed_files: + new_files.append(file_path) + + if new_files: + logger.info(f"Adding {len(new_files)} new documents to existing index") + colpali_model.index( + input_path=new_files, + index_name=index_name, + store_collection_with_index=False, + overwrite=False, + max_image_width=1024, + max_image_height=1024, + ) + else: + logger.info("No new documents to add to the index") + + else: + logger.info(f"Creating new index: {index_name}") + colpali_model.index( + input_path=index_path, + index_name=index_name, + store_collection_with_index=False, + overwrite=False, + max_image_width=1024, + max_image_height=1024, + ) + + # Remove extra folder if it exists + extra_folder = os.path.join(self.comfy_dir, settings.get("rag_folder_name")) + if os.path.exists(extra_folder) and os.path.isdir(extra_folder): + logger.debug(f"Removing extra folder: {extra_folder}") + shutil.rmtree(extra_folder) + + return True + + finally: + # Restore original process_images method + colpali_model.model.processor.process_images = original_process_images + + except Exception as e: + logger.error(f"Error during indexing: {str(e)}") + return False + + def load_indexed_images(self, index_name, size): + """Load previously indexed images from the stored index and convert new PDFs""" + try: + index_path = os.path.join(self.rag_dir, index_name, "input") + images_path = os.path.join(self.rag_dir, index_name, "converted_images") + os.makedirs(images_path, exist_ok=True) + + # Get sorted list of PDF files to ensure consistent ordering + pdf_files = sorted([f for f in os.listdir(index_path) if f.lower().endswith('.pdf')]) + + if not pdf_files: + logger.warning("No PDF files found in the input directory.") + return None + + all_images = {} + + # Process each PDF file + for doc_id, pdf_file in enumerate(pdf_files): + pdf_name = os.path.splitext(pdf_file)[0] + pdf_images_dir = os.path.join(images_path, pdf_name) + pdf_path = os.path.join(index_path, pdf_file) + + logger.debug(f"Processing PDF {pdf_file} as doc_id {doc_id}") + + # Check if conversion is needed + needs_conversion = True + if os.path.exists(pdf_images_dir) and os.listdir(pdf_images_dir): + pdf_mtime = os.path.getmtime(pdf_path) + newest_image = max( + os.path.getmtime(os.path.join(pdf_images_dir, f)) + for f in os.listdir(pdf_images_dir) + if f.endswith('.png') + ) + needs_conversion = pdf_mtime > newest_image + + if needs_conversion: + logger.info(f"Converting PDF: {pdf_file}") + os.makedirs(pdf_images_dir, exist_ok=True) + images = convert_from_path( + pdf_path, + thread_count=os.cpu_count() - 1, + fmt='png', + paths_only=False, + size=size, + ) + + # Save converted images with consistent naming + for page_num, img in enumerate(images, 1): + img_path = os.path.join(pdf_images_dir, f"page_{page_num:03d}.png") + img.save(img_path, "PNG") + logger.debug(f"Saved {img_path}") + + # Load images in correct order + page_files = sorted( + [f for f in os.listdir(pdf_images_dir) if f.endswith('.png')], + key=lambda x: int(x.split('_')[1].split('.')[0]) # Sort by page number + ) + + images = [] + for page_file in page_files: + img_path = os.path.join(pdf_images_dir, page_file) + try: + img = Image.open(img_path) + if img.mode != 'RGB': + img = img.convert('RGB') + images.append(img) + logger.debug(f"Loaded {img_path}") + except Exception as e: + logger.error(f"Error loading {img_path}: {e}") + continue + + all_images[doc_id] = images + logger.debug(f"Loaded {len(images)} pages for document {doc_id} ({pdf_file})") + + return all_images + + except Exception as e: + logger.error(f"Error loading indexed images: {str(e)}") + return None + + def get_top_results(self, results, all_images, index_name, llm_provider="transformers"): + """ + Get relevant images and their metadata based on search results. + Returns lists of images and masks along with result information. + + Args: + results: Search results to process + all_images: Dictionary of loaded images + index_name: Name of the index being used + llm_provider: The LLM provider being used (default: "transformers") + """ + top_results_images = [] + top_results_masks = [] + result_info = [] + + try: + # Get list of PDF files + index_path = os.path.join(self.rag_dir, index_name, "input") + pdf_files = sorted([f for f in os.listdir(index_path) if f.lower().endswith('.pdf')]) + + # Sort and filter results with safe score extraction + def get_score(result): + try: + return float(result.score) if hasattr(result, 'score') else 0.0 + except (ValueError, TypeError): + return 0.0 + + sorted_results = sorted(results, key=get_score, reverse=True) + + # For non-standard LLM providers, only take the highest scoring result + if llm_provider.lower() not in ["transformers", "openai", "anthropic"]: + sorted_results = sorted_results[:1] + + logger.debug("Processing sorted results:") + for r in sorted_results: + score = get_score(r) + try: + doc_id = int(r.doc_id) if isinstance(r.doc_id, str) else r.doc_id + page_num = int(r.page_num) if isinstance(r.page_num, str) else r.page_num + logger.debug(f"Doc: {doc_id}, Page: {page_num}, Score: {score}") + except (ValueError, TypeError) as e: + logger.warning(f"Invalid result format: {e}") + continue + + # Create a mapping of original doc_ids to ensure correct ordering + doc_id_map = {i: os.path.splitext(pdf_file)[0] for i, pdf_file in enumerate(pdf_files)} + + for result in sorted_results: + # Updated integer conversion logic + try: + doc_id = int(result.doc_id) if isinstance(result.doc_id, str) else result.doc_id + page_num = int(result.page_num) if isinstance(result.page_num, str) else result.page_num + except (ValueError, TypeError) as e: + logger.warning(f"Invalid doc_id or page_num format: {e}") + continue + + # Validate document ID + if doc_id not in all_images: + logger.warning(f"Document ID {doc_id} not found in loaded images") + continue + + # Validate page number (convert to 0-based index) + page_idx = page_num - 1 + if page_idx < 0 or page_idx >= len(all_images[doc_id]): + logger.warning(f"Invalid page {page_num} for document {doc_id}") + continue + + # Get image for this result + try: + image = all_images[doc_id][page_idx] + logger.debug(f"Retrieved image for doc {doc_id} ('{doc_id_map[doc_id]}'), page {page_num}") + + # Ensure image is in RGB format + if image.mode != 'RGB': + image = image.convert('RGB') + + # Convert to tensor + img_array = np.array(image).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + + # Create corresponding mask + mask_tensor = torch.ones((1, img_array.shape[0], img_array.shape[1]), + dtype=torch.float32, device="cpu") + + # Add to results + top_results_images.append(img_tensor) + top_results_masks.append(mask_tensor) + + # Store result info + result_info.append({ + "doc_id": doc_id, + "page_num": page_num, + "score": float(result.score) if hasattr(result, 'score') else 0.0, + "pdf_name": doc_id_map[doc_id], + "metadata": result.metadata if hasattr(result, 'metadata') else {} + }) + + except Exception as e: + logger.error(f"Error processing image for doc {doc_id}, page {page_num}: {e}") + continue + + if top_results_images: + # Combine tensors + try: + top_combined_images = torch.cat(top_results_images, dim=0) + top_combined_masks = torch.cat(top_results_masks, dim=0) + + # Add debug logging for tensor shapes + logger.debug(f"Number of images processed: {len(top_results_images)}") + logger.debug(f"Combined images tensor shape: {top_combined_images.shape}") + logger.debug(f"Combined masks tensor shape: {top_combined_masks.shape}") + + if llm_provider.lower() not in ["transformers", "openai", "anthropic"]: + # Verify we only have one image + assert top_combined_images.shape[0] == 1, "Expected only one image for non-standard LLM provider" + assert top_combined_masks.shape[0] == 1, "Expected only one mask for non-standard LLM provider" + logger.debug("Confirmed single image output for non-standard LLM provider") + + logger.debug("Final result order:") + for info in result_info: + logger.debug(f"Doc: {info['doc_id']} ({info['pdf_name']}), Page: {info['page_num']}, Score: {info['score']:.2f}") + + return top_combined_images, top_combined_masks, result_info + + except Exception as e: + logger.error(f"Error combining tensors: {e}") + return None, None, [] + + logger.debug("No valid images to process") + return None, None, [] + + except Exception as e: + logger.error(f"Error in get_top_results: {str(e)}") + return None, None, [] + + async def query(self, prompt: str, query_type: str, system_message_str: str, **kwargs): + try: + # 1. Initialize settings and parameters + settings: dict = self.graphrag_app.load_settings() + llm_provider: str = kwargs.pop("llm_provider", settings.get("llm_provider", "ollama")) + base_ip: str = kwargs.pop("base_ip", settings.get("base_ip", "localhost")) + port: str = kwargs.pop("port", settings.get("port", "11434")) + llm_model: str = kwargs.pop('llm_model', settings.get("llm_model", "llama3.1:latest")) + llm_api_key: str = settings.get('external_llm_api_key') if settings.get('external_llm_api_key') != "" else get_api_key(f"{settings['llm_provider'].upper()}_API_KEY", settings['llm_provider']) + keep_alive: str = kwargs.pop("keep_alive", settings.get("keep_alive", "False")) + seed: str = kwargs.pop("seed", settings.get("seed", "None")) + temperature: float = float(kwargs.pop("temperature", settings.get("temperature", "0.7"))) + top_p: float = float(kwargs.pop("top_p", settings.get("top_p", "0.90"))) + top_k: int = int(kwargs.pop("top_k", settings.get("top_k", "40"))) + max_tokens: int = int(kwargs.pop("max_tokens", settings.get("max_tokens", "2048"))) + presence_penalty: float = float(kwargs.pop("repeat_penalty", settings.get("repeat_penalty", "1.2"))) + random: str = kwargs.pop("random", settings.get("random", "False")) + stop: str = kwargs.pop("stop", settings.get("stop", "None")) + precision: str = kwargs.pop("precision", settings.get("precision", "fp16")) + attention: str = kwargs.pop("attention", settings.get("attention", "sdpa")) + index_name: str = kwargs.pop("rag_folder_name", settings.get("rag_folder_name")) + prime_directives: str = kwargs.pop("prime_directives", settings.get("prime_directives", "None")) + aspect_ratio: str = kwargs.pop("aspect_ratio", settings.get("aspect_ratio", "16:9")) + top_k_search: str = kwargs.pop("top_k_search", settings.get("top_k_search", "3")) + #vertical/horizontal is x/y on pdf2image that is why I inverted the aspect ratio + size: tuple[int, int] = (768, 1024) if aspect_ratio == "16:9" else (1024, 768) if aspect_ratio == "9:16" else (1024, 1024) + messages: list = [] + if prime_directives != "None": + system_message_str = prime_directives + elif system_message_str == "None": + system_message_str = "You are a helpful assistant. Analyze the image and answer the user's question." + + # 2. Get and validate model + colpali_model = await self._prepare_model(query_type, index_name) + if not colpali_model: + return self._create_error_response(prompt, "ColPali model not available") + + # 3. Perform search and validate results + top_k_search_int: int = int(top_k_search) + results: Union[List[str], List[Any]] = await self._perform_search( + colpali_model, + str(prompt), + top_k_search_int, + llm_provider=llm_provider + ) + if not results: + return self._create_error_response(prompt, "No relevant documents found", tool_output="Search returned no results") + + # 4. Process images + image_data: Optional[Union[ + Tuple[torch.Tensor, torch.Tensor, List[Dict]], + Tuple[torch.Tensor, torch.Tensor, List[str]] + ]] = await self._process_images( + results, + index_name, + size, + llm_provider=llm_provider + ) + if not image_data: + return self._create_error_response(prompt, "Failed to process images") + + images_tensor: torch.Tensor + masks_tensor: torch.Tensor + result_info: Union[List[Dict], List[str]] + images_tensor, masks_tensor, result_info = image_data + + logger.debug(f"Images tensor shape: {images_tensor.shape}") + logger.debug(f"Result info: {result_info}") + + try: + generated_text: Union[str, List[str]] + generated_text = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=images_tensor, + llm_model=llm_model, + system_message=system_message_str, + user_message=prompt, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=presence_penalty, + stop=stop, + keep_alive=keep_alive, + llm_api_key=llm_api_key if llm_api_key != "" else None, + precision=precision, + attention=attention, + ) + + # Handle case where generated_text is a list + if isinstance(generated_text, list): + generated_text = "\n".join(generated_text) # Join with newlines to preserve formatting + + except Exception as e: + logger.error(f"Error in API request: {str(e)}") + return { + "Question": prompt, + "Response": f"Error communicating with {llm_provider}: {str(e)}", + "Negative": "", + "Tool_Output": str(result_info), + "Retrieved_Image": images_tensor.detach() if torch.is_tensor(images_tensor) else None, + "Mask": masks_tensor.detach() if torch.is_tensor(masks_tensor) else None + } + + # 6. Format and return response + return { + "Question": prompt, + "Response": generated_text, + "Negative": "", + "Tool_Output": self._format_tool_output(result_info), + "Retrieved_Image": images_tensor, + "Mask": masks_tensor + } + + except Exception as e: + logger.error(f"Error in colpali query: {str(e)}") + return self._create_error_response(prompt, f"Error processing query: {str(e)}") + + async def _prepare_model(self, query_type, index_name): + """Prepare and validate the ColPali model""" + try: + # Check if we already have the correct index loaded + if self.cached_index_model is not None and self.cached_index_name == index_name: + logger.debug(f"Using cached index model for {index_name}") + return self.cached_index_model + + # Get base model + colpali_model = self.get_colpali_model(query_type) + if not colpali_model: + logger.error("Failed to get base ColPali model") + return None + + # Load new index + logger.debug(f"Loading new index {index_name} from .byaldi...") + try: + model = RAGMultiModalModel.from_index( + index_name, + index_root=".byaldi", + device="cuda", + verbose=1 + ) + + # Verify index loaded correctly + if not hasattr(model.model, 'indexed_embeddings') or not model.model.indexed_embeddings: + logger.error("Index loaded but no embeddings found") + return None + + # Cache the successfully loaded index + self.cached_index_model = model + self.cached_index_name = index_name + + logger.debug(f"Successfully loaded and cached index with {len(model.model.indexed_embeddings)} embeddings") + return model + + except Exception as e: + logger.error(f"Error loading index: {str(e)}") + return None + + except Exception as e: + logger.error(f"Error in _prepare_model: {str(e)}") + return None + + async def _perform_search(self, model, prompt, top_k_search, llm_provider="transformers"): + """ + Perform search and validate results + + Args: + model: The RAG model to use for search + prompt: The search prompt + top_k_search: Number of results to return + llm_provider: The LLM provider being used (default: "transformers") + """ + logger.debug(f"Searching with prompt: {prompt}") + + # Adjust top_k based on llm_provider + if llm_provider.lower() not in ["transformers", "openai", "anthropic"]: + top_k_search = 1 + + results = model.search(prompt, k=top_k_search) + + if results: + for result in results: + if hasattr(result, 'doc_id'): + result.doc_id = int(result.doc_id) if isinstance(result.doc_id, str) else result.doc_id + if hasattr(result, 'page_num'): + result.page_num = int(result.page_num) if isinstance(result.page_num, str) else result.page_num + return results + + async def _process_images(self, results, index_name, size, llm_provider="transformers"): + """Process and validate images from search results""" + all_images = self.load_indexed_images(index_name, size) + if all_images is None: + return None + + return self.get_top_results(results, all_images, index_name, llm_provider) + + def _format_tool_output(self, result_info): + """Format tool output text""" + tool_text = "Retrieved Documents:\n" + for info in result_info: + tool_text += f"\nDocument {info['doc_id']}, Page {info['page_num']}" + tool_text += f"\nRelevance Score: {info['score']:.2f}" + if info['metadata']: + tool_text += f"\nMetadata: {info['metadata']}" + tool_text += "\n" + return tool_text + + def _create_error_response(self, prompt, error_message, tool_output=None): + """Create standardized error response""" + return { + "Question": prompt, + "Response": error_message, + "Negative": "", + "Tool_Output": tool_output, + "Retrieved_Image": None, + "Mask": None + } + + def cleanup(self): + """Cleanup method to free up GPU memory when needed""" + if builtins.global_colpali_model: + del builtins.global_colpali_model + builtins.global_colpali_model = None + torch.cuda.empty_cache() + logger.info("Cleaned up models and freed GPU memory") + + def cleanup_index(self): + """Method to manually clear the cached index if needed""" + self.cached_index_model = None + self.cached_index_name = None + torch.cuda.empty_cache() + logger.info("Cleared cached index and freed GPU memory") + + async def _generate_text_response(self, images_tensor, prompt, system_message_str, params): + """Generate text response based on provider""" diff --git a/graphRAG_module.py b/graphRAG_module.py index 1acb1db..20d2796 100644 --- a/graphRAG_module.py +++ b/graphRAG_module.py @@ -1,487 +1,487 @@ -import os -import re -import sys -import glob -import uuid -import yaml -import time -import json -import queue -import shutil -import asyncio -import logging -import aiohttp -import requests -import importlib -import traceback -import folder_paths -# Set up logging -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -import numpy as np -from nano_graphrag.graphrag import GraphRAG, QueryParam -from nano_graphrag.base import BaseKVStorage -from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs -#from litellm import completion, embedding, text_completion - -from .send_request import create_embedding, send_request - -# Set LiteLLM to be verbose -#from litellm import set_verbose -set_verbose = True - -logging.basicConfig(level=logging.WARNING) -logging.getLogger("nano-graphrag").setLevel(logging.INFO) - -from .utils import get_api_key - - - -from .graph_visualize_tool import visualize_graph - -class GraphRAGapp: - def __init__(self): - self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "rag") - self._rag_root_dir = None - self._input_dir = None - self.embedding_func = None - self.graphrag = None - - @property - def rag_root_dir(self): - return self._rag_root_dir - - @rag_root_dir.setter - def rag_root_dir(self, value): - self._rag_root_dir = value - self._input_dir = os.path.join(value, "input") if value else None - logger.debug(f"rag_root_dir setter: set to {self._rag_root_dir}") - logger.debug(f"input_dir set to {self._input_dir}") - if self._input_dir: - os.makedirs(self._input_dir, exist_ok=True) - logger.debug(f"Created input directory: {self._input_dir}") - - def set_rag_root_dir(self, rag_folder_name): - if rag_folder_name: - new_rag_root_dir = os.path.join(self.rag_dir, rag_folder_name) - else: - new_rag_root_dir = os.path.join(self.rag_dir, "rag_data") - - self._rag_root_dir = new_rag_root_dir - self._input_dir = os.path.join(new_rag_root_dir, "input") - - # Ensure directories exist - os.makedirs(self._rag_root_dir, exist_ok=True) - os.makedirs(self._input_dir, exist_ok=True) - - logger.debug(f"set_rag_root_dir: rag_root_dir set to {self._rag_root_dir}") - logger.debug(f"set_rag_root_dir: input_dir set to {self._input_dir}") - return self._rag_root_dir - - def _save_settings_to_path(self, settings_path): - """Save settings to a specific path, overwriting if it already exists.""" - try: - with open(settings_path, 'w') as f: - yaml.dump( - self.settings, - f, - default_flow_style=False, - sort_keys=False, - allow_unicode=True, - default_style=None, - Dumper=yaml.SafeDumper, - ) - logger.info(f"Settings saved to {settings_path}") - except Exception as e: - logger.error(f"Error saving settings to {settings_path}: {str(e)}") - - def save_settings(self): - """Save settings to both the RAG-specific folder and the main RAG directory.""" - if self.settings_path: - self._save_settings_to_path(self.settings_path) - else: - logger.warning("RAG-specific settings path not set. Unable to save settings to RAG folder.") - - # Save a copy to the main RAG directory - rag_dir_settings_path = os.path.join(self.rag_dir, 'settings.yaml') - self._save_settings_to_path(rag_dir_settings_path) - - return self.settings - - async def setup_and_initialize_folder(self, rag_folder_name, settings): - try: - rag_root = os.path.join(self.rag_dir, rag_folder_name) - logger.debug(f"rag_root set to: {rag_root}") - self.settings_path = os.path.join(rag_root, 'settings.yaml') - - self._rag_root_dir = rag_root - self._input_dir = os.path.join(rag_root, "input") - - os.makedirs(rag_root, exist_ok=True) - logger.info(f"Created/ensured folder: {rag_root}") - - # Create the input directory - os.makedirs(self._input_dir, exist_ok=True) - logger.info(f"Created/ensured input directory: {self._input_dir}") - - # Update settings.yaml with UI settings - self.settings = self._create_settings_from_ui(settings) - self.save_settings() - - # Add a short delay to ensure settings are saved - await asyncio.sleep(1) - - # Create the GraphRAG instance here - await self.setup_embedding_func() - self.graphrag = GraphRAG( - working_dir=self._rag_root_dir, - enable_llm_cache=True, - best_model_func=self.unified_model_if_cache, - cheap_model_func=self.unified_model_if_cache, - embedding_func=self.embedding_func, - ) - - result = { - "status": "success", - "message": f"Folder initialized: {rag_root}", - "rag_root_dir": rag_root, - } - logger.debug(f"Final result: {result}") - logger.debug(f"self.rag_root_dir after initialization: {self.rag_root_dir}") - return result - - except Exception as e: - logger.error(f"Error in setup_and_initialize_folder: {str(e)}") - return {"status": "error", "message": str(e)} - - def _create_settings_from_ui(self, ui_settings): - """ - Create settings.yaml from UI settings with proper type conversion. - """ - settings = { - 'embedding_provider': str(ui_settings.get('embedding_provider', 'sentence_transformers')), - 'embedding_model': str(ui_settings.get('embedding_model', 'avsolatorio/GIST-small-Embedding-v0')), - 'base_ip': str(ui_settings.get('base_ip', 'localhost')), - 'port': str(ui_settings.get('port', '11434')), - 'llm_provider': str(ui_settings.get('llm_provider', 'ollama')), - 'llm_model': str(ui_settings.get('llm_model', 'llama3.1:latest')), - 'temperature': float(ui_settings.get('temperature', '0.7')), - 'max_tokens': int(ui_settings.get('max_tokens', '2048')), - 'stop': None if ui_settings.get('stop', 'None') == 'None' else str(ui_settings.get('stop')), - 'keep_alive': ui_settings.get('keep_alive', 'False').lower() == 'true', # Convert to boolean - 'top_k': int(ui_settings.get('top_k', '40')), - 'top_p': float(ui_settings.get('top_p', '0.90')), - 'repeat_penalty': float(ui_settings.get('repeat_penalty', '1.2')), - 'seed': None if ui_settings.get('seed', 'None') == 'None' else int(ui_settings.get('seed')), - 'rag_folder_name': str(ui_settings.get('rag_folder_name', 'rag_data')), - 'query_type': str(ui_settings.get('query_type', 'global')), - 'community_level': int(ui_settings.get('community_level', '2')), - 'preset': str(ui_settings.get('preset', 'Default')), - 'external_llm_api_key': str(ui_settings.get('external_llm_api_key', '')), - 'random': ui_settings.get('random', 'False').lower() == 'true', # Convert to boolean - 'prime_directives': None if ui_settings.get('prime_directives', 'None') == 'None' else str(ui_settings.get('prime_directives')), - 'prompt': str(ui_settings.get('prompt', 'Who helped Safiro infiltrate the Zaltar Organisation?')), - 'response_format': str(ui_settings.get('response_format', 'json')), - 'precision': str(ui_settings.get('precision', 'fp16')), - 'attention': str(ui_settings.get('attention', 'sdpa')), - 'aspect_ratio': str(ui_settings.get('aspect_ratio', '16:9')), - 'top_k_search': int(ui_settings.get('top_k_search', '3')), - } - return settings - - def load_settings(self): - if self._rag_root_dir: - self.settings_path = os.path.join(self._rag_root_dir, "settings.yaml") - else: - self.settings_path = os.path.join(self.rag_dir, "settings.yaml") - - if os.path.exists(self.settings_path): - with open(self.settings_path, 'r') as f: - try: - self.settings = yaml.safe_load(f) - logger.info(f"Loaded settings from {self.settings_path}") - except yaml.YAMLError as e: - logger.error(f"Error parsing settings file: {str(e)}") - self.settings = {} - else: - logger.warning(f"Settings file not found at {self.settings_path}") - self.settings = {} - - return self.settings - - async def setup_embedding_func(self, **kwargs) -> None: - settings = self.load_settings() - base_ip = kwargs.pop("base_ip", settings.get("base_ip", "localhost")) - port = kwargs.pop("port", settings.get("port", "11434")) - #base64_image = kwargs.pop("base64_image", settings.get("base64_image", None)) - embedding_provider = settings.get('embedding_provider', 'sentence_transformers') - embedding_model = settings.get('embedding_model', 'avsolatorio/GIST-small-Embedding-v0') - - embedding_api_key = settings.get('external_llm_api_key') if settings.get('external_llm_api_key') != "" else get_api_key(f"{embedding_provider.upper()}_API_KEY", embedding_provider) - - api_base = f"http://{base_ip}:{port}" if embedding_provider in ["ollama", "lmstudio", "llamacpp", "textgen"] else f"https://api.{embedding_provider}.com" - - if embedding_provider in ["openai", "mistral", "lmstudio", "llamacpp", "textgen", "ollama"]: - embedding_dim = 1536 if embedding_provider in ["openai", "mistral"] else 768 - @wrap_embedding_func_with_attrs(embedding_dim=embedding_dim, max_token_size=8192) - async def embedding_func(texts: list[str]) -> np.ndarray: - embeddings = [] # Initialize embeddings as a list - - for text in texts: # Iterate through each text in the input list - embedding = await create_embedding( - embedding_provider, api_base, embedding_model, [text], embedding_api_key # Send single text at a time - ) - if embedding is None: - raise ValueError( - f"Failed to generate embeddings with {embedding_provider}/{embedding_model}" - ) - embeddings.append(embedding) # Append individual embedding to list - - return np.array(embeddings) # Convert list of embeddings to NumPy array - - elif embedding_provider == "sentence_transformers": - from sentence_transformers import SentenceTransformer - EMBED_MODEL = SentenceTransformer(embedding_model) - embedding_dim = EMBED_MODEL.get_sentence_embedding_dimension() - max_token_size = EMBED_MODEL.max_seq_length - - @wrap_embedding_func_with_attrs( - embedding_dim=embedding_dim, max_token_size=max_token_size - ) - async def embedding_func(texts: list[str]) -> np.ndarray: - return EMBED_MODEL.encode(texts, normalize_embeddings=True) - - self.embedding_func = embedding_func - - def remove_if_exist(self, file): - if os.path.exists(file): - os.remove(file) - - - async def unified_model_if_cache(self, prompt, system_prompt=None, history_messages=[], **kwargs) -> str: - settings = self.load_settings() - logger.info(f"Loaded settings for LLM: {settings}") - base_ip = kwargs.pop("base_ip", settings.get("base_ip", "localhost")) - port = kwargs.pop("port", settings.get("port", "11434")) - llm_provider = kwargs.pop("llm_provider", settings.get("llm_provider", "ollama")) - llm_model = kwargs.pop("llm_model", settings.get("llm_model", "llama3.2:latest")) - temperature = float(kwargs.pop("temperature", settings.get("temperature", "0.7"))) - max_tokens = int(kwargs.pop("max_tokens", settings.get("max_tokens", "2048"))) - keep_alive = kwargs.pop("keep_alive", settings.get("keep_alive", "False")) - top_k = int(kwargs.pop("top_k", settings.get("top_k", "50"))) - top_p = float(kwargs.pop("top_p", settings.get("top_p", "0.95"))) - presence_penalty = float(kwargs.pop("repeat_penalty", settings.get("repeat_penalty", "1.2"))) - llm_api_key = settings.get('external_llm_api_key') if settings.get('external_llm_api_key') != "" else get_api_key(f"{settings['llm_provider'].upper()}_API_KEY", settings['llm_provider']) - seed = kwargs.pop("seed", settings.get("seed", "None")) - random = kwargs.pop("random", settings.get("random", "False")) - response_format = kwargs.pop("response_format", settings.get("response_format", "json")) - stop = kwargs.pop("stop", settings.get("stop", "None")) - if stop is None or stop.lower() == "none": - stop = None - - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - - if hashing_kv is not None: - args_hash = compute_args_hash(llm_model, messages) - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] - - print(f"Prompt: {prompt}") - user_message = prompt - try: - response = await send_request( - llm_provider=llm_provider, - base_ip=base_ip, - port=port, - images=None, - llm_model=llm_model, - system_message=system_prompt, - user_message=user_message, - messages=messages, - llm_api_key=llm_api_key, - seed=seed, - random=random, - stop=stop, - keep_alive=keep_alive, - temperature=temperature, - top_p=top_p, - top_k=top_k, - max_tokens=max_tokens, - repeat_penalty=presence_penalty, - tools=None, - tool_choice=None - ) - - # Handle different response formats - if isinstance(response, dict) and "message" in response: - result = response["message"]["content"] - elif isinstance(response, dict) and "content" in response: - result = response["content"] - elif isinstance(response, str): - result = response - else: - raise ValueError(f"Unexpected response format: {type(response)}") - - if hashing_kv is not None: - await hashing_kv.upsert({args_hash: {"return": result, "model": llm_model}}) - return result - - except Exception as e: - logger.error(f"Error during LLM completion: {str(e)}") - logger.error(f"Response type: {type(response)}") - logger.error(f"Response content: {response}") - raise ValueError(f"Error during LLM completion: {str(e)}") - - def get_preset_values(self, preset, kwargs, settings): - preset_values = { - "Default": ("2", "Multiple Paragraphs"), - "Detailed": ("4", "Multi-Page Report"), - "Quick": ("1", "Single Paragraph"), - "Bullet": ("2", "List of 3-7 Points"), - "Comprehensive": ("5", "Multi-Page Report"), - "High-Level": ("1", "Single Page"), - "Focused": ("3", "Multiple Paragraphs"), - } - - if preset.startswith(tuple(preset_values.keys())): - return preset_values[preset.split()[0]] - elif preset == "Custom Query": - return ( - kwargs.pop("community_level", settings.get("community_level", "2")), - kwargs.pop("response_type", settings.get("response_type", "Multiple Paragraphs")) - ) - else: - return ("2", "Multiple Paragraphs") - - async def query(self, prompt, query_type, preset): - logger.debug(f"Query - GraphRAG instance id: {id(self.graphrag)}") - logger.debug(f"Query - Working directory: {self._rag_root_dir}") - - settings = self.load_settings() - working_dir = os.path.join(self.rag_dir, settings.get("rag_folder_name")) - print(f"Working directory: {working_dir}") - - if self.graphrag is None: - logger.info("GraphRAG instance not initialized. Initializing...") - await self.setup_embedding_func() - self.graphrag = GraphRAG( - working_dir=working_dir, - enable_llm_cache=True, - best_model_func=self.unified_model_if_cache, - cheap_model_func=self.unified_model_if_cache, - embedding_func=self.embedding_func, - ) - - community_level, response_type = self.get_preset_values(preset, {}, settings) - print(f"Community level: {community_level}, Response type: {response_type}") - - for filename in ["vdb_entities.json", "kv_store_full_docs.json", "kv_store_text_chunks.json"]: - file_path = os.path.join(working_dir, filename) - if os.path.exists(file_path): - logger.debug(f"File exists: {file_path}, size: {os.path.getsize(file_path)} bytes") - else: - logger.warning(f"File not found: {file_path}") - - try: - result = await self.graphrag.aquery( - query=prompt, - param=QueryParam( - mode=query_type, - response_type=response_type, - level=int(community_level) - ) - ) - - # Define the dynamic path for the GraphML file - graphml_path = os.path.join(self.graphrag.working_dir, "graph_chunk_entity_relation.graphml") - - # Call the visualize_graph function to visualize the graph - try: - visualize_graph(graphml_path) - except Exception as viz_error: - logger.error(f"Error visualizing graph: {str(viz_error)}") - print(f"Error visualizing graph: {str(viz_error)}") - - - return result, graphml_path - except Exception as e: - logger.error(f"Error in GraphRAGapp.query: {str(e)}") - logger.error(traceback.format_exc()) - return f"Error during query: {str(e)}" - - async def insert(self): - logger.debug("Starting insert function...") - logger.debug(f"Insert - rag_dir: {self.rag_dir}") - logger.debug(f"Insert - _rag_root_dir: {self._rag_root_dir}") - logger.debug(f"Insert - _input_dir: {self._input_dir}") - settings = self.load_settings() - print(f"Settings: {settings}") - - working_dir = self._rag_root_dir - insert_input_dir = self._input_dir - - print(f"Working directory: {working_dir}") - print(f"Insert input directory: {insert_input_dir}") - try: - logger.debug(f"Listing files in {insert_input_dir}") - all_texts = [] - for filename in os.listdir(insert_input_dir): - if filename.endswith(".txt"): - file_path = os.path.join(insert_input_dir, filename) - logger.debug(f"Reading file: {file_path}") - with open(file_path, encoding="utf-8-sig") as f: - all_texts.append(f.read()) - - if not all_texts: - logger.warning("No text files found in the input directory.") - return False - - combined_text = "\n".join(all_texts) - logger.debug(f"Combined text length: {len(combined_text)}") - - # Remove existing files - logger.debug("Removing existing files...") - for filename in ["vdb_entities.json", "kv_store_full_docs.json", "kv_store_text_chunks.json", "kv_store_community_reports.json", "graph_chunk_entity_relation.graphml"]: - self.remove_if_exist(os.path.join(working_dir, filename)) - - logger.debug("Creating GraphRAG instance...") - - # Set up the embedding function before creating the GraphRAG instance - await self.setup_embedding_func() - - # Use the existing graphrag instance or create a new one - if self.graphrag is None: - self.graphrag = GraphRAG( - working_dir=working_dir, - enable_llm_cache=True, - best_model_func=self.unified_model_if_cache, - cheap_model_func=self.unified_model_if_cache, - embedding_func=self.embedding_func, - ) - - start = time.time() - logger.debug("Inserting text...") - await self.graphrag.ainsert(combined_text) - logger.debug(f"Indexing completed in {time.time() - start:.2f} seconds") - print("indexing time:", time.time() - start) - - # Cleanup step - extra_folder = os.path.join(self.comfy_dir, settings.get("rag_folder_name")) - if os.path.exists(extra_folder) and os.path.isdir(extra_folder): - logger.debug(f"Removing extra folder: {extra_folder}") - shutil.rmtree(extra_folder) - - return True - - except Exception as e: - logger.error(f"Error during indexing: {str(e)}") - logger.error(traceback.format_exc()) +import os +import re +import sys +import glob +import uuid +import yaml +import time +import json +import queue +import shutil +import asyncio +import logging +import aiohttp +import requests +import importlib +import traceback +import folder_paths +# Set up logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +import numpy as np +from nano_graphrag.graphrag import GraphRAG, QueryParam +from nano_graphrag.base import BaseKVStorage +from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs +#from litellm import completion, embedding, text_completion + +from .send_request import create_embedding, send_request + +# Set LiteLLM to be verbose +#from litellm import set_verbose +set_verbose = True + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("nano-graphrag").setLevel(logging.INFO) + +from .utils import get_api_key + + + +from .graph_visualize_tool import visualize_graph + +class GraphRAGapp: + def __init__(self): + self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "rag") + self._rag_root_dir = None + self._input_dir = None + self.embedding_func = None + self.graphrag = None + + @property + def rag_root_dir(self): + return self._rag_root_dir + + @rag_root_dir.setter + def rag_root_dir(self, value): + self._rag_root_dir = value + self._input_dir = os.path.join(value, "input") if value else None + logger.debug(f"rag_root_dir setter: set to {self._rag_root_dir}") + logger.debug(f"input_dir set to {self._input_dir}") + if self._input_dir: + os.makedirs(self._input_dir, exist_ok=True) + logger.debug(f"Created input directory: {self._input_dir}") + + def set_rag_root_dir(self, rag_folder_name): + if rag_folder_name: + new_rag_root_dir = os.path.join(self.rag_dir, rag_folder_name) + else: + new_rag_root_dir = os.path.join(self.rag_dir, "rag_data") + + self._rag_root_dir = new_rag_root_dir + self._input_dir = os.path.join(new_rag_root_dir, "input") + + # Ensure directories exist + os.makedirs(self._rag_root_dir, exist_ok=True) + os.makedirs(self._input_dir, exist_ok=True) + + logger.debug(f"set_rag_root_dir: rag_root_dir set to {self._rag_root_dir}") + logger.debug(f"set_rag_root_dir: input_dir set to {self._input_dir}") + return self._rag_root_dir + + def _save_settings_to_path(self, settings_path): + """Save settings to a specific path, overwriting if it already exists.""" + try: + with open(settings_path, 'w') as f: + yaml.dump( + self.settings, + f, + default_flow_style=False, + sort_keys=False, + allow_unicode=True, + default_style=None, + Dumper=yaml.SafeDumper, + ) + logger.info(f"Settings saved to {settings_path}") + except Exception as e: + logger.error(f"Error saving settings to {settings_path}: {str(e)}") + + def save_settings(self): + """Save settings to both the RAG-specific folder and the main RAG directory.""" + if self.settings_path: + self._save_settings_to_path(self.settings_path) + else: + logger.warning("RAG-specific settings path not set. Unable to save settings to RAG folder.") + + # Save a copy to the main RAG directory + rag_dir_settings_path = os.path.join(self.rag_dir, 'settings.yaml') + self._save_settings_to_path(rag_dir_settings_path) + + return self.settings + + async def setup_and_initialize_folder(self, rag_folder_name, settings): + try: + rag_root = os.path.join(self.rag_dir, rag_folder_name) + logger.debug(f"rag_root set to: {rag_root}") + self.settings_path = os.path.join(rag_root, 'settings.yaml') + + self._rag_root_dir = rag_root + self._input_dir = os.path.join(rag_root, "input") + + os.makedirs(rag_root, exist_ok=True) + logger.info(f"Created/ensured folder: {rag_root}") + + # Create the input directory + os.makedirs(self._input_dir, exist_ok=True) + logger.info(f"Created/ensured input directory: {self._input_dir}") + + # Update settings.yaml with UI settings + self.settings = self._create_settings_from_ui(settings) + self.save_settings() + + # Add a short delay to ensure settings are saved + await asyncio.sleep(1) + + # Create the GraphRAG instance here + await self.setup_embedding_func() + self.graphrag = GraphRAG( + working_dir=self._rag_root_dir, + enable_llm_cache=True, + best_model_func=self.unified_model_if_cache, + cheap_model_func=self.unified_model_if_cache, + embedding_func=self.embedding_func, + ) + + result = { + "status": "success", + "message": f"Folder initialized: {rag_root}", + "rag_root_dir": rag_root, + } + logger.debug(f"Final result: {result}") + logger.debug(f"self.rag_root_dir after initialization: {self.rag_root_dir}") + return result + + except Exception as e: + logger.error(f"Error in setup_and_initialize_folder: {str(e)}") + return {"status": "error", "message": str(e)} + + def _create_settings_from_ui(self, ui_settings): + """ + Create settings.yaml from UI settings with proper type conversion. + """ + settings = { + 'embedding_provider': str(ui_settings.get('embedding_provider', 'sentence_transformers')), + 'embedding_model': str(ui_settings.get('embedding_model', 'avsolatorio/GIST-small-Embedding-v0')), + 'base_ip': str(ui_settings.get('base_ip', 'localhost')), + 'port': str(ui_settings.get('port', '11434')), + 'llm_provider': str(ui_settings.get('llm_provider', 'ollama')), + 'llm_model': str(ui_settings.get('llm_model', 'llama3.1:latest')), + 'temperature': float(ui_settings.get('temperature', '0.7')), + 'max_tokens': int(ui_settings.get('max_tokens', '2048')), + 'stop': None if ui_settings.get('stop', 'None') == 'None' else str(ui_settings.get('stop')), + 'keep_alive': ui_settings.get('keep_alive', 'False').lower() == 'true', # Convert to boolean + 'top_k': int(ui_settings.get('top_k', '40')), + 'top_p': float(ui_settings.get('top_p', '0.90')), + 'repeat_penalty': float(ui_settings.get('repeat_penalty', '1.2')), + 'seed': None if ui_settings.get('seed', 'None') == 'None' else int(ui_settings.get('seed')), + 'rag_folder_name': str(ui_settings.get('rag_folder_name', 'rag_data')), + 'query_type': str(ui_settings.get('query_type', 'global')), + 'community_level': int(ui_settings.get('community_level', '2')), + 'preset': str(ui_settings.get('preset', 'Default')), + 'external_llm_api_key': str(ui_settings.get('external_llm_api_key', '')), + 'random': ui_settings.get('random', 'False').lower() == 'true', # Convert to boolean + 'prime_directives': None if ui_settings.get('prime_directives', 'None') == 'None' else str(ui_settings.get('prime_directives')), + 'prompt': str(ui_settings.get('prompt', 'Who helped Safiro infiltrate the Zaltar Organisation?')), + 'response_format': str(ui_settings.get('response_format', 'json')), + 'precision': str(ui_settings.get('precision', 'fp16')), + 'attention': str(ui_settings.get('attention', 'sdpa')), + 'aspect_ratio': str(ui_settings.get('aspect_ratio', '16:9')), + 'top_k_search': int(ui_settings.get('top_k_search', '3')), + } + return settings + + def load_settings(self): + if self._rag_root_dir: + self.settings_path = os.path.join(self._rag_root_dir, "settings.yaml") + else: + self.settings_path = os.path.join(self.rag_dir, "settings.yaml") + + if os.path.exists(self.settings_path): + with open(self.settings_path, 'r') as f: + try: + self.settings = yaml.safe_load(f) + logger.info(f"Loaded settings from {self.settings_path}") + except yaml.YAMLError as e: + logger.error(f"Error parsing settings file: {str(e)}") + self.settings = {} + else: + logger.warning(f"Settings file not found at {self.settings_path}") + self.settings = {} + + return self.settings + + async def setup_embedding_func(self, **kwargs) -> None: + settings = self.load_settings() + base_ip = kwargs.pop("base_ip", settings.get("base_ip", "localhost")) + port = kwargs.pop("port", settings.get("port", "11434")) + #base64_image = kwargs.pop("base64_image", settings.get("base64_image", None)) + embedding_provider = settings.get('embedding_provider', 'sentence_transformers') + embedding_model = settings.get('embedding_model', 'avsolatorio/GIST-small-Embedding-v0') + + embedding_api_key = settings.get('external_llm_api_key') if settings.get('external_llm_api_key') != "" else get_api_key(f"{embedding_provider.upper()}_API_KEY", embedding_provider) + + api_base = f"http://{base_ip}:{port}" if embedding_provider in ["ollama", "lmstudio", "llamacpp", "textgen"] else f"https://api.{embedding_provider}.com" + + if embedding_provider in ["openai", "mistral", "lmstudio", "llamacpp", "textgen", "ollama"]: + embedding_dim = 1536 if embedding_provider in ["openai", "mistral"] else 768 + @wrap_embedding_func_with_attrs(embedding_dim=embedding_dim, max_token_size=8192) + async def embedding_func(texts: list[str]) -> np.ndarray: + embeddings = [] # Initialize embeddings as a list + + for text in texts: # Iterate through each text in the input list + embedding = await create_embedding( + embedding_provider, api_base, embedding_model, [text], embedding_api_key # Send single text at a time + ) + if embedding is None: + raise ValueError( + f"Failed to generate embeddings with {embedding_provider}/{embedding_model}" + ) + embeddings.append(embedding) # Append individual embedding to list + + return np.array(embeddings) # Convert list of embeddings to NumPy array + + elif embedding_provider == "sentence_transformers": + from sentence_transformers import SentenceTransformer + EMBED_MODEL = SentenceTransformer(embedding_model) + embedding_dim = EMBED_MODEL.get_sentence_embedding_dimension() + max_token_size = EMBED_MODEL.max_seq_length + + @wrap_embedding_func_with_attrs( + embedding_dim=embedding_dim, max_token_size=max_token_size + ) + async def embedding_func(texts: list[str]) -> np.ndarray: + return EMBED_MODEL.encode(texts, normalize_embeddings=True) + + self.embedding_func = embedding_func + + def remove_if_exist(self, file): + if os.path.exists(file): + os.remove(file) + + + async def unified_model_if_cache(self, prompt, system_prompt=None, history_messages=[], **kwargs) -> str: + settings = self.load_settings() + logger.info(f"Loaded settings for LLM: {settings}") + base_ip = kwargs.pop("base_ip", settings.get("base_ip", "localhost")) + port = kwargs.pop("port", settings.get("port", "11434")) + llm_provider = kwargs.pop("llm_provider", settings.get("llm_provider", "ollama")) + llm_model = kwargs.pop("llm_model", settings.get("llm_model", "llama3.2:latest")) + temperature = float(kwargs.pop("temperature", settings.get("temperature", "0.7"))) + max_tokens = int(kwargs.pop("max_tokens", settings.get("max_tokens", "2048"))) + keep_alive = kwargs.pop("keep_alive", settings.get("keep_alive", "False")) + top_k = int(kwargs.pop("top_k", settings.get("top_k", "50"))) + top_p = float(kwargs.pop("top_p", settings.get("top_p", "0.95"))) + presence_penalty = float(kwargs.pop("repeat_penalty", settings.get("repeat_penalty", "1.2"))) + llm_api_key = settings.get('external_llm_api_key') if settings.get('external_llm_api_key') != "" else get_api_key(f"{settings['llm_provider'].upper()}_API_KEY", settings['llm_provider']) + seed = kwargs.pop("seed", settings.get("seed", "None")) + random = kwargs.pop("random", settings.get("random", "False")) + response_format = kwargs.pop("response_format", settings.get("response_format", "json")) + stop = kwargs.pop("stop", settings.get("stop", "None")) + if stop is None or stop.lower() == "none": + stop = None + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + if hashing_kv is not None: + args_hash = compute_args_hash(llm_model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + print(f"Prompt: {prompt}") + user_message = prompt + try: + response = await send_request( + llm_provider=llm_provider, + base_ip=base_ip, + port=port, + images=None, + llm_model=llm_model, + system_message=system_prompt, + user_message=user_message, + messages=messages, + llm_api_key=llm_api_key, + seed=seed, + random=random, + stop=stop, + keep_alive=keep_alive, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_tokens=max_tokens, + repeat_penalty=presence_penalty, + tools=None, + tool_choice=None + ) + + # Handle different response formats + if isinstance(response, dict) and "message" in response: + result = response["message"]["content"] + elif isinstance(response, dict) and "content" in response: + result = response["content"] + elif isinstance(response, str): + result = response + else: + raise ValueError(f"Unexpected response format: {type(response)}") + + if hashing_kv is not None: + await hashing_kv.upsert({args_hash: {"return": result, "model": llm_model}}) + return result + + except Exception as e: + logger.error(f"Error during LLM completion: {str(e)}") + logger.error(f"Response type: {type(response)}") + logger.error(f"Response content: {response}") + raise ValueError(f"Error during LLM completion: {str(e)}") + + def get_preset_values(self, preset, kwargs, settings): + preset_values = { + "Default": ("2", "Multiple Paragraphs"), + "Detailed": ("4", "Multi-Page Report"), + "Quick": ("1", "Single Paragraph"), + "Bullet": ("2", "List of 3-7 Points"), + "Comprehensive": ("5", "Multi-Page Report"), + "High-Level": ("1", "Single Page"), + "Focused": ("3", "Multiple Paragraphs"), + } + + if preset.startswith(tuple(preset_values.keys())): + return preset_values[preset.split()[0]] + elif preset == "Custom Query": + return ( + kwargs.pop("community_level", settings.get("community_level", "2")), + kwargs.pop("response_type", settings.get("response_type", "Multiple Paragraphs")) + ) + else: + return ("2", "Multiple Paragraphs") + + async def query(self, prompt, query_type, preset): + logger.debug(f"Query - GraphRAG instance id: {id(self.graphrag)}") + logger.debug(f"Query - Working directory: {self._rag_root_dir}") + + settings = self.load_settings() + working_dir = os.path.join(self.rag_dir, settings.get("rag_folder_name")) + print(f"Working directory: {working_dir}") + + if self.graphrag is None: + logger.info("GraphRAG instance not initialized. Initializing...") + await self.setup_embedding_func() + self.graphrag = GraphRAG( + working_dir=working_dir, + enable_llm_cache=True, + best_model_func=self.unified_model_if_cache, + cheap_model_func=self.unified_model_if_cache, + embedding_func=self.embedding_func, + ) + + community_level, response_type = self.get_preset_values(preset, {}, settings) + print(f"Community level: {community_level}, Response type: {response_type}") + + for filename in ["vdb_entities.json", "kv_store_full_docs.json", "kv_store_text_chunks.json"]: + file_path = os.path.join(working_dir, filename) + if os.path.exists(file_path): + logger.debug(f"File exists: {file_path}, size: {os.path.getsize(file_path)} bytes") + else: + logger.warning(f"File not found: {file_path}") + + try: + result = await self.graphrag.aquery( + query=prompt, + param=QueryParam( + mode=query_type, + response_type=response_type, + level=int(community_level) + ) + ) + + # Define the dynamic path for the GraphML file + graphml_path = os.path.join(self.graphrag.working_dir, "graph_chunk_entity_relation.graphml") + + # Call the visualize_graph function to visualize the graph + try: + visualize_graph(graphml_path) + except Exception as viz_error: + logger.error(f"Error visualizing graph: {str(viz_error)}") + print(f"Error visualizing graph: {str(viz_error)}") + + + return result, graphml_path + except Exception as e: + logger.error(f"Error in GraphRAGapp.query: {str(e)}") + logger.error(traceback.format_exc()) + return f"Error during query: {str(e)}" + + async def insert(self): + logger.debug("Starting insert function...") + logger.debug(f"Insert - rag_dir: {self.rag_dir}") + logger.debug(f"Insert - _rag_root_dir: {self._rag_root_dir}") + logger.debug(f"Insert - _input_dir: {self._input_dir}") + settings = self.load_settings() + print(f"Settings: {settings}") + + working_dir = self._rag_root_dir + insert_input_dir = self._input_dir + + print(f"Working directory: {working_dir}") + print(f"Insert input directory: {insert_input_dir}") + try: + logger.debug(f"Listing files in {insert_input_dir}") + all_texts = [] + for filename in os.listdir(insert_input_dir): + if filename.endswith(".txt"): + file_path = os.path.join(insert_input_dir, filename) + logger.debug(f"Reading file: {file_path}") + with open(file_path, encoding="utf-8-sig") as f: + all_texts.append(f.read()) + + if not all_texts: + logger.warning("No text files found in the input directory.") + return False + + combined_text = "\n".join(all_texts) + logger.debug(f"Combined text length: {len(combined_text)}") + + # Remove existing files + logger.debug("Removing existing files...") + for filename in ["vdb_entities.json", "kv_store_full_docs.json", "kv_store_text_chunks.json", "kv_store_community_reports.json", "graph_chunk_entity_relation.graphml"]: + self.remove_if_exist(os.path.join(working_dir, filename)) + + logger.debug("Creating GraphRAG instance...") + + # Set up the embedding function before creating the GraphRAG instance + await self.setup_embedding_func() + + # Use the existing graphrag instance or create a new one + if self.graphrag is None: + self.graphrag = GraphRAG( + working_dir=working_dir, + enable_llm_cache=True, + best_model_func=self.unified_model_if_cache, + cheap_model_func=self.unified_model_if_cache, + embedding_func=self.embedding_func, + ) + + start = time.time() + logger.debug("Inserting text...") + await self.graphrag.ainsert(combined_text) + logger.debug(f"Indexing completed in {time.time() - start:.2f} seconds") + print("indexing time:", time.time() - start) + + # Cleanup step + extra_folder = os.path.join(self.comfy_dir, settings.get("rag_folder_name")) + if os.path.exists(extra_folder) and os.path.isdir(extra_folder): + logger.debug(f"Removing extra folder: {extra_folder}") + shutil.rmtree(extra_folder) + + return True + + except Exception as e: + logger.error(f"Error during indexing: {str(e)}") + logger.error(traceback.format_exc()) return False \ No newline at end of file diff --git a/graph_visualize_tool.py b/graph_visualize_tool.py index a105b6a..3a8fff6 100644 --- a/graph_visualize_tool.py +++ b/graph_visualize_tool.py @@ -1,292 +1,292 @@ -import networkx as nx -import json -import webbrowser -import os -import http.server -import socketserver -import threading -import subprocess -import sys -import platform -import shutil - -def graphml_to_json(graphml_file): - G = nx.read_graphml(graphml_file) - data = nx.node_link_data(G) - return json.dumps(data) - -def create_html(json_data, html_path): - json_data = json_data.replace('\\"', '') - html_content = ''' - - - - - - Graph Visualization - - - - - -
-
- - - - '''.replace("{json_data}", json_data.replace("'", "\\'").replace("\n", "")) - - with open(html_path, 'w', encoding='utf-8') as f: - f.write(html_content) - -def start_server(port=8189): - handler = http.server.SimpleHTTPRequestHandler - with socketserver.TCPServer(("", port), handler) as httpd: - print(f"Server started at http://localhost:{port}") - httpd.serve_forever() - -def visualize_graph(graphml_file): - if not os.path.exists(graphml_file): - print(f"GraphML file not found: {graphml_file}") - return - - html_path = "graph_visualization.html" - try: - json_data = graphml_to_json(graphml_file) - create_html(json_data, html_path) - except Exception as e: - print(f"Error creating visualization: {str(e)}") - return - - port = 8189 - try: - server_thread = threading.Thread(target=start_server, args=(port,)) - server_thread.daemon = True - server_thread.start() - except OSError as err: - if "Address already in use" in str(err): - print(f"Port {port} is already in use. Trying a different port...") - for new_port in range(port + 1, port + 10): # Try the next few ports - try: - server_thread = threading.Thread(target=start_server, args=(new_port,)) - server_thread.daemon = True - server_thread.start() - port = new_port - print(f"Server started at http://localhost:{port}") - break # Success, exit loop - except OSError as e: - if "Address already in use" not in str(e): - raise # Re-raise if it's not "Address in use" error. - else: # If no ports were available - raise RuntimeError(f"Could not find an available port between {port+1} and {port+9}.") from err - else: - raise # Re-raise the OSError if it's not about the address being in use - - # Open default browser - try: - webbrowser.open(f'http://localhost:{port}/{html_path}') - except Exception as e: - print(f"Error opening browser: {str(e)}") - - print("Graph visualization is ready. The browser should open automatically.") - print(f"If the browser doesn't open, please visit http://localhost:{port}/{html_path}") - print("The server will continue running in the background.") +import networkx as nx +import json +import webbrowser +import os +import http.server +import socketserver +import threading +import subprocess +import sys +import platform +import shutil + +def graphml_to_json(graphml_file): + G = nx.read_graphml(graphml_file) + data = nx.node_link_data(G) + return json.dumps(data) + +def create_html(json_data, html_path): + json_data = json_data.replace('\\"', '') + html_content = ''' + + + + + + Graph Visualization + + + + + +
+
+ + + + '''.replace("{json_data}", json_data.replace("'", "\\'").replace("\n", "")) + + with open(html_path, 'w', encoding='utf-8') as f: + f.write(html_content) + +def start_server(port=8189): + handler = http.server.SimpleHTTPRequestHandler + with socketserver.TCPServer(("", port), handler) as httpd: + print(f"Server started at http://localhost:{port}") + httpd.serve_forever() + +def visualize_graph(graphml_file): + if not os.path.exists(graphml_file): + print(f"GraphML file not found: {graphml_file}") + return + + html_path = "graph_visualization.html" + try: + json_data = graphml_to_json(graphml_file) + create_html(json_data, html_path) + except Exception as e: + print(f"Error creating visualization: {str(e)}") + return + + port = 8189 + try: + server_thread = threading.Thread(target=start_server, args=(port,)) + server_thread.daemon = True + server_thread.start() + except OSError as err: + if "Address already in use" in str(err): + print(f"Port {port} is already in use. Trying a different port...") + for new_port in range(port + 1, port + 10): # Try the next few ports + try: + server_thread = threading.Thread(target=start_server, args=(new_port,)) + server_thread.daemon = True + server_thread.start() + port = new_port + print(f"Server started at http://localhost:{port}") + break # Success, exit loop + except OSError as e: + if "Address already in use" not in str(e): + raise # Re-raise if it's not "Address in use" error. + else: # If no ports were available + raise RuntimeError(f"Could not find an available port between {port+1} and {port+9}.") from err + else: + raise # Re-raise the OSError if it's not about the address being in use + + # Open default browser + try: + webbrowser.open(f'http://localhost:{port}/{html_path}') + except Exception as e: + print(f"Error opening browser: {str(e)}") + + print("Graph visualization is ready. The browser should open automatically.") + print(f"If the browser doesn't open, please visit http://localhost:{port}/{html_path}") + print("The server will continue running in the background.") print("You can close the browser tab when you're done viewing the graph.") \ No newline at end of file diff --git a/graphrag_config.yml b/graphrag_config.yml index fca1fa8..e818d86 100644 --- a/graphrag_config.yml +++ b/graphrag_config.yml @@ -1,13 +1,13 @@ -working_dir: "./graphrag_data" -llm_provider: "ollama" -llm_model: "qwen2" -embedding_provider: "sentence_transformers" -embedding_model: "all-MiniLM-L6-v2" -api_base: "http://localhost:11434" -enable_rag: true -query_type: "local" -community_level: 2 -response_type: "Detailed" -additional_params: - max_tokens: 2048 +working_dir: "./graphrag_data" +llm_provider: "ollama" +llm_model: "qwen2" +embedding_provider: "sentence_transformers" +embedding_model: "all-MiniLM-L6-v2" +api_base: "http://localhost:11434" +enable_rag: true +query_type: "local" +community_level: 2 +response_type: "Detailed" +additional_params: + max_tokens: 2048 temperature: 0.7 \ No newline at end of file diff --git a/lms_api.py b/lms_api.py index 97b3c05..41c1f05 100644 --- a/lms_api.py +++ b/lms_api.py @@ -1,180 +1,180 @@ -#lms_api.py -import requests -import json -from typing import List, Union, Optional -import aiohttp -import asyncio -import logging -logger = logging.getLogger(__name__) - -def create_lmstudio_compatible_embedding(api_base: str, model: str, input: Union[str, List[str]], api_key: Optional[str] = None) -> List[float]: - """ - Create embeddings using an lmstudio-compatible API. - - :param api_base: The base URL for the API - :param model: The name of the model to use for embeddings - :param input: A string or list of strings to embed - :param api_key: The API key (if required) - :return: A list of embeddings - """ - # Normalize the API base URL - api_base = api_base.rstrip('/') - if not api_base.endswith('/v1'): - api_base += '/v1' - - url = f"{api_base}/embeddings" - - headers = { - "Content-Type": "application/json" - } - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - payload = { - "model": model, - "input": input, - "encoding_format": "float" - } - - try: - response = requests.post(url, headers=headers, json=payload) - response.raise_for_status() - result = response.json() - - if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]: - # If multiple embeddings are returned, we'll just use the first one - return result["data"][0]["embedding"] - else: - raise ValueError("Unexpected response format: 'embedding' data not found") - except requests.RequestException as e: - raise RuntimeError(f"Error calling embedding API: {str(e)}") - -async def send_lmstudio_request(api_url, base64_images, model, system_message, user_message, messages, seed, temperature, - max_tokens, top_k, top_p, repeat_penalty, stop, tools=None, tool_choice=None): - headers = { - "Content-Type": "application/json" - } - - data = { - "model": model, - "messages": prepare_lmstudio_messages(system_message, user_message, messages, base64_images), - "temperature": temperature, - "max_tokens": max_tokens, - "presence_penalty": repeat_penalty, - "top_p": top_p, - "top_k": top_k, - "seed": seed - } - - if stop: - data["stop"] = stop - if tools: - data["functions"] = tools - if tool_choice: - data["function_call"] = tool_choice - - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - response.raise_for_status() - response_data = await response.json() - - choices = response_data.get('choices', []) - if choices: - choice = choices[0] - message = choice.get('message', {}) - if "function_call" in message: - return { - "choices": [{ - "message": { - "function_call": { - "name": message["function_call"]["name"], - "arguments": message["function_call"]["arguments"] - } - } - }] - } - else: - generated_text = message.get('content', '') - return { - "choices": [{ - "message": { - "content": generated_text - } - }] - } - else: - error_msg = "Error: No valid choices in the LMStudio response." - print(error_msg) - return {"choices": [{"message": {"content": error_msg}}]} - except aiohttp.ClientError as e: - error_msg = f"Error in LMStudio API request: {e}" - print(error_msg) - return {"choices": [{"message": {"content": error_msg}}]} - -def prepare_lmstudio_messages(base64_images, system_message, user_message, messages): - lmstudio_messages = [] - - if system_message: - lmstudio_messages.append({"role": "system", "content": system_message}) - - for message in messages: - role = message["role"] - content = message["content"] - - if role == "system": - lmstudio_messages.append({"role": "system", "content": content}) - elif role == "user": - lmstudio_messages.append({"role": "user", "content": content}) - elif role == "assistant": - lmstudio_messages.append({"role": "assistant", "content": content}) - - # Add the current user message with all images if provided - if base64_images: - content = [{"type": "text", "text": user_message}] - for base64_image in base64_images: - content.append({ - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - } - }) - lmstudio_messages.append({ - "role": "user", - "content": content - }) - print(f"Number of images sent: {len(base64_images)}") - else: - lmstudio_messages.append({"role": "user", "content": user_message}) - - return lmstudio_messages - - -"""def prepare_lmstudio_messages(system_message, user_message, messages, base64_images=None): - lmstudio_messages = [ - {"role": "system", "content": system_message}, - ] - - for message in messages: - if isinstance(message["content"], list): - # Handle multi-modal content - content = [] - for item in message["content"]: - if item["type"] == "text": - content.append(item["text"]) - elif item["type"] == "image_url": - content.append(f"[Image data: {item['image_url']['url']}]") - lmstudio_messages.append({"role": message["role"], "content": " ".join(content)}) - else: - lmstudio_messages.append(message) - - if base64_images: - image_content = "\n".join([f"[Image data: data:image/jpeg;base64,{img}]" for img in base64_images]) - lmstudio_messages.append({ - "role": "user", - "content": f"{user_message}\n{image_content}" - }) - else: - lmstudio_messages.append({"role": "user", "content": user_message}) - +#lms_api.py +import requests +import json +from typing import List, Union, Optional +import aiohttp +import asyncio +import logging +logger = logging.getLogger(__name__) + +def create_lmstudio_compatible_embedding(api_base: str, model: str, input: Union[str, List[str]], api_key: Optional[str] = None) -> List[float]: + """ + Create embeddings using an lmstudio-compatible API. + + :param api_base: The base URL for the API + :param model: The name of the model to use for embeddings + :param input: A string or list of strings to embed + :param api_key: The API key (if required) + :return: A list of embeddings + """ + # Normalize the API base URL + api_base = api_base.rstrip('/') + if not api_base.endswith('/v1'): + api_base += '/v1' + + url = f"{api_base}/embeddings" + + headers = { + "Content-Type": "application/json" + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + payload = { + "model": model, + "input": input, + "encoding_format": "float" + } + + try: + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + result = response.json() + + if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]: + # If multiple embeddings are returned, we'll just use the first one + return result["data"][0]["embedding"] + else: + raise ValueError("Unexpected response format: 'embedding' data not found") + except requests.RequestException as e: + raise RuntimeError(f"Error calling embedding API: {str(e)}") + +async def send_lmstudio_request(api_url, base64_images, model, system_message, user_message, messages, seed, temperature, + max_tokens, top_k, top_p, repeat_penalty, stop, tools=None, tool_choice=None): + headers = { + "Content-Type": "application/json" + } + + data = { + "model": model, + "messages": prepare_lmstudio_messages(system_message, user_message, messages, base64_images), + "temperature": temperature, + "max_tokens": max_tokens, + "presence_penalty": repeat_penalty, + "top_p": top_p, + "top_k": top_k, + "seed": seed + } + + if stop: + data["stop"] = stop + if tools: + data["functions"] = tools + if tool_choice: + data["function_call"] = tool_choice + + try: + async with aiohttp.ClientSession() as session: + async with session.post(api_url, headers=headers, json=data) as response: + response.raise_for_status() + response_data = await response.json() + + choices = response_data.get('choices', []) + if choices: + choice = choices[0] + message = choice.get('message', {}) + if "function_call" in message: + return { + "choices": [{ + "message": { + "function_call": { + "name": message["function_call"]["name"], + "arguments": message["function_call"]["arguments"] + } + } + }] + } + else: + generated_text = message.get('content', '') + return { + "choices": [{ + "message": { + "content": generated_text + } + }] + } + else: + error_msg = "Error: No valid choices in the LMStudio response." + print(error_msg) + return {"choices": [{"message": {"content": error_msg}}]} + except aiohttp.ClientError as e: + error_msg = f"Error in LMStudio API request: {e}" + print(error_msg) + return {"choices": [{"message": {"content": error_msg}}]} + +def prepare_lmstudio_messages(base64_images, system_message, user_message, messages): + lmstudio_messages = [] + + if system_message: + lmstudio_messages.append({"role": "system", "content": system_message}) + + for message in messages: + role = message["role"] + content = message["content"] + + if role == "system": + lmstudio_messages.append({"role": "system", "content": content}) + elif role == "user": + lmstudio_messages.append({"role": "user", "content": content}) + elif role == "assistant": + lmstudio_messages.append({"role": "assistant", "content": content}) + + # Add the current user message with all images if provided + if base64_images: + content = [{"type": "text", "text": user_message}] + for base64_image in base64_images: + content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + } + }) + lmstudio_messages.append({ + "role": "user", + "content": content + }) + print(f"Number of images sent: {len(base64_images)}") + else: + lmstudio_messages.append({"role": "user", "content": user_message}) + + return lmstudio_messages + + +"""def prepare_lmstudio_messages(system_message, user_message, messages, base64_images=None): + lmstudio_messages = [ + {"role": "system", "content": system_message}, + ] + + for message in messages: + if isinstance(message["content"], list): + # Handle multi-modal content + content = [] + for item in message["content"]: + if item["type"] == "text": + content.append(item["text"]) + elif item["type"] == "image_url": + content.append(f"[Image data: {item['image_url']['url']}]") + lmstudio_messages.append({"role": message["role"], "content": " ".join(content)}) + else: + lmstudio_messages.append(message) + + if base64_images: + image_content = "\n".join([f"[Image data: data:image/jpeg;base64,{img}]" for img in base64_images]) + lmstudio_messages.append({ + "role": "user", + "content": f"{user_message}\n{image_content}" + }) + else: + lmstudio_messages.append({"role": "user", "content": user_message}) + return lmstudio_messages""" \ No newline at end of file diff --git a/openai_api.py b/openai_api.py index f6118b8..65cad1d 100644 --- a/openai_api.py +++ b/openai_api.py @@ -176,7 +176,13 @@ def prepare_openai_messages(base64_images, system_message, user_message, message return openai_messages -async def generate_image(prompt: str, model: str = "dall-e-3", n: int = 1, size: str = "1024x1024", api_key: Optional[str] = None) -> List[str]: +async def generate_image( + prompt: str, + model: str = "dall-e-3", + n: int = 1, + size: str = "1024x1024", + api_key: Optional[str] = None +) -> List[str]: """ Generate images from a text prompt using DALL·E. @@ -197,22 +203,30 @@ async def generate_image(prompt: str, model: str = "dall-e-3", n: int = 1, size: "prompt": prompt, "n": n, "size": size, - "response_format": "url" # Change to "b64_json" for Base64 + "response_format": "b64_json" } async with aiohttp.ClientSession() as session: async with session.post(api_url, headers=headers, json=payload) as response: response.raise_for_status() data = await response.json() - images = [item["url"] for item in data.get("data", [])] + images = [item["b64_json"] for item in data.get("data", [])] return images -async def edit_image(image_path: str, mask_path: str, prompt: str, model: str = "dall-e-2", n: int = 1, size: str = "1024x1024", api_key: Optional[str] = None) -> List[str]: +async def edit_image( + image_base64: str, + mask_base64: str, + prompt: str, + model: str = "dall-e-2", + n: int = 1, + size: str = "1024x1024", + api_key: Optional[str] = None +) -> List[str]: """ Edit an existing image by replacing areas defined by a mask using DALL·E. - :param image_path: Path to the original image file. - :param mask_path: Path to the mask image file. + :param image_base64: Base64-encoded original image. + :param mask_base64: Base64-encoded mask image. :param prompt: The text prompt describing the desired edits. :param model: The model to use ("dall-e-2"). :param n: Number of edited images to generate. @@ -224,29 +238,36 @@ async def edit_image(image_path: str, mask_path: str, prompt: str, model: str = headers = { "Authorization": f"Bearer {api_key}" } + payload = { + "model": model, + "prompt": prompt, + "n": n, + "size": size, + "response_format": "b64_json" + } + files = { + "image": image_base64, + "mask": mask_base64 + } - with open(image_path, "rb") as img_file, open(mask_path, "rb") as mask_file: - files = { - "model": (None, model), - "image": (os.path.basename(image_path), img_file, "image/png"), - "mask": (os.path.basename(mask_path), mask_file, "image/png"), - "prompt": (None, prompt), - "n": (None, str(n)), - "size": (None, size) - } - - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, data=files) as response: - response.raise_for_status() - data = await response.json() - images = [item["url"] for item in data.get("data", [])] - return images + async with aiohttp.ClientSession() as session: + async with session.post(api_url, headers=headers, json=payload) as response: + response.raise_for_status() + data = await response.json() + images = [item["b64_json"] for item in data.get("data", [])] + return images -async def generate_image_variations(image_path: str, model: str = "dall-e-2", n: int = 1, size: str = "1024x1024", api_key: Optional[str] = None) -> List[str]: +async def generate_image_variations( + image_base64: str, + model: str = "dall-e-2", + n: int = 1, + size: str = "1024x1024", + api_key: Optional[str] = None +) -> List[str]: """ Generate variations of an existing image using DALL·E. - :param image_path: Path to the original image file. + :param image_base64: Base64-encoded original image. :param model: The model to use ("dall-e-2"). :param n: Number of variations to generate. :param size: Size of the generated images. @@ -257,21 +278,22 @@ async def generate_image_variations(image_path: str, model: str = "dall-e-2", n: headers = { "Authorization": f"Bearer {api_key}" } + payload = { + "model": model, + "n": n, + "size": size, + "response_format": "b64_json" + } + files = { + "image": image_base64 + } - with open(image_path, "rb") as img_file: - files = { - "model": (None, model), - "image": (os.path.basename(image_path), img_file, "image/png"), - "n": (None, str(n)), - "size": (None, size) - } - - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, data=files) as response: - response.raise_for_status() - data = await response.json() - images = [item["url"] for item in data.get("data", [])] - return images + async with aiohttp.ClientSession() as session: + async with session.post(api_url, headers=headers, json=payload) as response: + response.raise_for_status() + data = await response.json() + images = [item["b64_json"] for item in data.get("data", [])] + return images async def text_to_speech(text: str, model: str = "tts-1", voice: str = "alloy", response_format: str = "mp3", output_path: str = "speech.mp3", api_key: Optional[str] = None) -> None: """ diff --git a/pyproject.toml b/pyproject.toml index 66340d8..04e06dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,39 +1,39 @@ -[project] -name = "comfyui_if_ai_tools" -description = "ComfyUI-IF_AI_tools is a set of custom nodes to Run Local and API LLMs and LMMs, Features OCR-RAG (Bialdy), nanoGraphRAG, Supervision Object Detection, supports Ollama, LlamaCPP LMstudio, Koboldcpp, TextGen, Transformers or via APIs Anthropic, Groq, OpenAI, Google Gemini, Mistral, xAI and create your own charcters assistants (SystemPrompts) with custom presets and muchmore" -version = "1.0.1" -license = { file = "LICENSE.txt" } -dependencies = ["anthropic", -"groq", -"mistralai", -"huggingface_hub", -"pypdf2", -"pdf2image", -"timm", -"sentence-transformers", -"byaldi", -"opencv-python", -"IPython", -"python-dotenv", -"nltk", -"tiktoken", -"matplotlib", -"plotly", -"kaleido", -"networkx", -"fastparquet", -"pydantic", -"rich", -"supervision", -"nano-graphrag", -"qwen-vl-utils" -] - -[project.urls] -Repository = "https://github.com/if-ai/ComfyUI-IF_AI_tools" -# Used by Comfy Registry https://comfyregistry.org - -[tool.comfy] -PublisherId = "impactframes" -DisplayName = "ComfyUI_IF_AI_tools" -Icon = "" +[project] +name = "comfyui-if_ai_tools" +description = "Run Local and API LLMs, Features OCR-RAG (Bialdy), nanoGraphRAG, Supervision Object Detection, Conditioning manipulation via Omost, supports Ollama, LlamaCPP LMstudio, Koboldcpp, TextGen, Transformers or via APIs Anthropic, Groq, OpenAI, Google Gemini, Mistral, xAI and create your own charcters assistants (SystemPrompts) with custom presets and muchmore" +version = "1.0.1" +license = { file = "LICENSE.txt" } +dependencies = ["anthropic", +"groq", +"mistralai", +"huggingface_hub", +"pypdf2", +"pdf2image", +"timm", +"sentence-transformers", +"byaldi", +"opencv-python", +"IPython", +"python-dotenv", +"nltk", +"tiktoken", +"matplotlib", +"plotly", +"kaleido", +"networkx", +"fastparquet", +"pydantic", +"rich", +"supervision", +"nano-graphrag", +"qwen-vl-utils" +] + +[project.urls] +Repository = "https://github.com/if-ai/ComfyUI-IF_AI_tools" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "impactframes" +DisplayName = "ComfyUI-IF_AI_tools" +Icon = "" diff --git a/send_request.py b/send_request.py index 0982d66..a3a7be5 100644 --- a/send_request.py +++ b/send_request.py @@ -1,357 +1,402 @@ -#send_request.py -import aiohttp -import asyncio -import json -import logging -from typing import List, Union, Optional, Dict, Any -#from json_repair import repair_json - -# Existing imports -from .anthropic_api import send_anthropic_request -from .ollama_api import send_ollama_request, create_ollama_embedding -from .openai_api import send_openai_request, create_openai_compatible_embedding -from .xai_api import send_xai_request -from .kobold_api import send_kobold_request -from .groq_api import send_groq_request -from .lms_api import send_lmstudio_request -from .textgen_api import send_textgen_request -from .llamacpp_api import send_llama_cpp_request -from .mistral_api import send_mistral_request -from .vllm_api import send_vllm_request -from .gemini_api import send_gemini_request -from .transformers_api import TransformersModelManager # Import the manager -from .utils import convert_images_for_api, format_response -# Set up logging -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Initialize the TransformersModelManager -_transformers_manager = TransformersModelManager() # <-- Removed models_dir parameter - -"""class MockCompletion: - def __init__(self, **kwargs): - # Initialize all attributes to None - for key in ['choices', 'id', 'object', 'created', 'model', 'usage', 'message']: - setattr(self, key, None) - - # Update attributes based on kwargs - self.__dict__.update(kwargs) - - # Ensure 'choices' has at least one default choice if not provided - if not self.choices and hasattr(self, 'message') and self.message.get("content"): - self.choices = [{ - "message": { - "content": self.message["content"] - }, - "finish_reason": "stop", - "index": 0 - }]""" - -async def send_request( - llm_provider: str, - base_ip: str, - port: str, - images: List[str], - llm_model: str, - system_message: str, - user_message: str, - messages: List[Dict[str, Any]], - seed: Optional[int], - temperature: float, - max_tokens: int, - random: bool, - top_k: int, - top_p: float, - repeat_penalty: float, - stop: Optional[List[str]], - keep_alive: bool, - llm_api_key: Optional[str] = None, - tools: Optional[Any] = None, - tool_choice: Optional[Any] = None, - precision: Optional[str] = "fp16", - attention: Optional[str] = "sdpa", -) -> Union[str, Dict[str, Any]]: - """ - Sends a request to the specified LLM provider and returns a unified response. - - Args: - llm_provider (str): The LLM provider to use. - base_ip (str): Base IP address for the API. - port (int): Port number for the API. - base64_images (List[str]): List of images encoded in base64. - llm_model (str): The model to use. - system_message (str): System message for the LLM. - user_message (str): User message for the LLM. - messages (List[Dict[str, Any]]): Conversation messages. - seed (Optional[int]): Random seed. - temperature (float): Temperature for randomness. - max_tokens (int): Maximum tokens to generate. - random (bool): Whether to use randomness. - top_k (int): Top K for sampling. - top_p (float): Top P for sampling. - repeat_penalty (float): Penalty for repetition. - stop (Optional[List[str]]): Stop sequences. - keep_alive (bool): Whether to keep the session alive. - llm_api_key (Optional[str], optional): API key for the LLM provider. - tools (Optional[Any], optional): Tools to be used. - tool_choice (Optional[Any], optional): Tool choice. - precision (Optional[str], optional): Precision for the model. - attention (Optional[str], optional): Attention mechanism for the model. - - Returns: - Union[str, Dict[str, Any]]: Unified response format. - """ - try: - # Convert images to base64 format for API consumption - if llm_provider == "transformers": - # For transformers, we'll pass PIL images - pil_images = convert_images_for_api(images, target_format='pil') if images is not None and len(images) > 0 else None - response = await _transformers_manager.send_transformers_request( - model_name=llm_model, - system_message=system_message, - user_message=user_message, - messages=messages, - max_new_tokens=max_tokens, - images=pil_images, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stop_strings_list=stop, - repetition_penalty=repeat_penalty, - seed=seed, - keep_alive=keep_alive, - precision=precision, - attention=attention - ) - return response - else: - # For other providers, convert to base64 only if images exist - base64_images = convert_images_for_api(images, target_format='base64') if images is not None and len(images) > 0 else None - - api_functions = { - "groq": send_groq_request, - "anthropic": send_anthropic_request, - "openai": send_openai_request, - "xai": send_xai_request, - "kobold": send_kobold_request, - "ollama": send_ollama_request, - "lmstudio": send_lmstudio_request, - "textgen": send_textgen_request, - "llamacpp": send_llama_cpp_request, - "mistral": send_mistral_request, - "vllm": send_vllm_request, - "gemini": send_gemini_request, - "transformers": None, # Handled separately - } - - if llm_provider not in api_functions and llm_provider != "transformers": - raise ValueError(f"Invalid llm_provider: {llm_provider}") - - if llm_provider == "transformers": - # This should be handled above, but included for safety - raise ValueError("Transformers provider should be handled separately.") - else: - # Existing logic for other providers - api_function = api_functions[llm_provider] - # Prepare API-specific keyword arguments - kwargs = {} - - if llm_provider == "ollama": - api_url = f"http://{base_ip}:{port}/api/chat" - kwargs = dict( - api_url=api_url, - base64_images=base64_images, - model=llm_model, - system_message=system_message, - user_message=user_message, - messages=messages, - seed=seed, - temperature=temperature, - max_tokens=max_tokens, - random=random, - top_k=top_k, - top_p=top_p, - repeat_penalty=repeat_penalty, - stop=stop, - keep_alive=keep_alive, - tools=tools, - tool_choice=tool_choice, - ) - elif llm_provider in ["kobold", "lmstudio", "textgen", "llamacpp", "vllm"]: - api_url = f"http://{base_ip}:{port}/v1/chat/completions" - kwargs = { - "api_url": api_url, - "base64_images": base64_images, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "seed": seed, - "temperature": temperature, - "max_tokens": max_tokens, - "top_k": top_k, - "top_p": top_p, - "repeat_penalty": repeat_penalty, - "stop": stop, - "tools": tools, - "tool_choice": tool_choice, - } - if llm_provider == "llamacpp": - kwargs.pop("tool_choice", None) - elif llm_provider == "vllm": - kwargs["api_key"] = llm_api_key - elif llm_provider == "gemini": - kwargs = { - "base64_images": base64_images, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens, - "top_k": top_k, - "top_p": top_p, - "stop": stop, - "api_key": llm_api_key, - "tools": tools, - "tool_choice": tool_choice, - } - elif llm_provider == "openai": - api_url = f"https://api.openai.com/v1/chat/completions" - kwargs = { - "api_url": api_url, - "base64_images": base64_images, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "api_key": llm_api_key, - "seed": seed if random else None, - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p, - "repeat_penalty": repeat_penalty, - "tools": tools, - "tool_choice": tool_choice, - } - elif llm_provider == "xai": - api_url = f"https://api.x.ai/v1/chat/completions" - kwargs = { - "api_url": api_url, - "base64_images": base64_images, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "api_key": llm_api_key, - "seed": seed if random else None, - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p, - "repeat_penalty": repeat_penalty, - "tools": tools, - "tool_choice": tool_choice, - } - elif llm_provider == "anthropic": - kwargs = { - "api_key": llm_api_key, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens, - "base64_images": base64_images, - "tools": tools, - "tool_choice": tool_choice - } - elif llm_provider == "groq": - kwargs = { - "base64_images": base64_images, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "api_key": llm_api_key, - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p, - "tools": tools, - "tool_choice": tool_choice, - } - elif llm_provider == "mistral": - kwargs = { - "base64_images": base64_images, - "model": llm_model, - "system_message": system_message, - "user_message": user_message, - "messages": messages, - "api_key": llm_api_key, - "seed": seed if random else None, - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p, - "tools": tools, - "tool_choice": tool_choice, - } - else: - raise ValueError(f"Unsupported llm_provider: {llm_provider}") - - response = await api_function(**kwargs) - - if isinstance(response, dict): - choices = response.get("choices", []) - if choices and "content" in choices[0].get("message", {}): - content = choices[0]["message"]["content"] - if content.startswith("Error:"): - print(f"Error from {llm_provider} API: {content}") - if tools: - return response - else: - try: - return response["choices"][0]["message"]["content"] - except (KeyError, IndexError, TypeError) as e: - error_msg = f"Error formatting response: {str(e)}" - logger.error(error_msg) - return {"choices": [{"message": {"content": error_msg}}]} - - except Exception as e: - logger.error(f"Exception in send_request: {str(e)}", exc_info=True) - return {"choices": [{"message": {"content": f"Exception: {str(e)}"}}]} - -def response_format_handler(response: Dict[str, Any], tools: Optional[Any]) -> Union[str, Dict[str, Any]]: - """ - Formats the response based on the desired response format. - - Args: - response (Dict[str, Any]): The raw response from the API. - tools (Optional[Any]): Tools that might affect the response. - response_format (str): 'text' or 'json'. - - Returns: - Union[str, Dict[str, Any]]: Formatted response. - """ - if tools: - return response - else: - try: - return response["choices"][0]["message"]["content"] - except (KeyError, IndexError, TypeError) as e: - error_msg = f"Error formatting response: {str(e)}" - logger.error(error_msg) - return {"choices": [{"message": {"content": error_msg}}]} - -async def create_embedding(embedding_provider: str, api_base: str, embedding_model: str, input: Union[str, List[str]], embedding_api_key: Optional[str] = None) -> Union[List[float], None]: # Correct return type hint - if embedding_provider == "ollama": - return await create_ollama_embedding(api_base, embedding_model, input) - - - elif embedding_provider in ["openai", "lmstudio", "llamacpp", "textgen", "mistral", "xai"]: - try: - return await create_openai_compatible_embedding(api_base, embedding_model, input, embedding_api_key) # Try block for more precise error handling - except ValueError as e: - print(f"Error creating embedding: {e}") # Log the specific error - return None # Return None on error - - else: - raise ValueError(f"Unsupported embedding_provider: {embedding_provider}") +#send_request.py +import aiohttp +import asyncio +import json +import logging +from typing import List, Union, Optional, Dict, Any +#from json_repair import repair_json + +# Existing imports +from .anthropic_api import send_anthropic_request +from .ollama_api import send_ollama_request, create_ollama_embedding +from .openai_api import send_openai_request, create_openai_compatible_embedding, generate_image, generate_image_variations, edit_image +from .xai_api import send_xai_request +from .kobold_api import send_kobold_request +from .groq_api import send_groq_request +from .lms_api import send_lmstudio_request +from .textgen_api import send_textgen_request +from .llamacpp_api import send_llama_cpp_request +from .mistral_api import send_mistral_request +from .vllm_api import send_vllm_request +from .gemini_api import send_gemini_request +from .transformers_api import TransformersModelManager # Import the manager +from .utils import format_images_for_provider, convert_images_for_api, format_response +# Set up logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Initialize the TransformersModelManager +_transformers_manager = TransformersModelManager() + + +def run_async(coroutine): + """Helper function to run coroutines in a new event loop if necessary""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coroutine) + +async def send_request( + llm_provider: str, + base_ip: str, + port: str, + images: List[str], + llm_model: str, + system_message: str, + user_message: str, + messages: List[Dict[str, Any]], + seed: Optional[int], + temperature: float, + max_tokens: int, + random: bool, + top_k: int, + top_p: float, + repeat_penalty: float, + stop: Optional[List[str]], + keep_alive: bool, + llm_api_key: Optional[str] = None, + tools: Optional[Any] = None, + tool_choice: Optional[Any] = None, + precision: Optional[str] = "fp16", + attention: Optional[str] = "sdpa", + aspect_ratio: Optional[str] = "1:1", + strategy: Optional[str] = "normal", + batch_count: Optional[int] = 4, + mask: Optional[str] = None, +) -> Union[str, Dict[str, Any]]: + """ + Sends a request to the specified LLM provider and returns a unified response. + + Args: + llm_provider (str): The LLM provider to use. + base_ip (str): Base IP address for the API. + port (int): Port number for the API. + base64_images (List[str]): List of images encoded in base64. + llm_model (str): The model to use. + system_message (str): System message for the LLM. + user_message (str): User message for the LLM. + messages (List[Dict[str, Any]]): Conversation messages. + seed (Optional[int]): Random seed. + temperature (float): Temperature for randomness. + max_tokens (int): Maximum tokens to generate. + random (bool): Whether to use randomness. + top_k (int): Top K for sampling. + top_p (float): Top P for sampling. + repeat_penalty (float): Penalty for repetition. + stop (Optional[List[str]]): Stop sequences. + keep_alive (bool): Whether to keep the session alive. + llm_api_key (Optional[str], optional): API key for the LLM provider. + tools (Optional[Any], optional): Tools to be used. + tool_choice (Optional[Any], optional): Tool choice. + precision (Optional[str], optional): Precision for the model. + attention (Optional[str], optional): Attention mechanism for the model. + aspect_ratio (Optional[str], optional): Desired aspect ratio for image generation/editing. + Options: "1:1", "4:5", "3:4", "5:4", "16:9", "9:16". Defaults to "1:1". + image_mode (Optional[str], optional): Mode for image processing. + Options: "create", "edit", "variations". Defaults to "create". + + Returns: + Union[str, Dict[str, Any]]: Unified response format. + """ + try: + #formatted_images = format_images_for_provider(images, llm_provider) if images is not None else None + #formatted_mask = format_images_for_provider(mask, llm_provider) if mask is not None else None + # Define aspect ratio to size mapping + aspect_ratio_mapping = { + "1:1": "1024x1024", + "4:5": "1024x1280", + "3:4": "1024x1365", + "5:4": "1280x1024", + "16:9": "1600x900", + "9:16": "900x1600" + } + + # Get the size based on the provided aspect_ratio + size = aspect_ratio_mapping.get(aspect_ratio.lower(), "1024x1024") # Default to square if invalid + + # Convert images to base64 format for API consumption + if llm_provider == "transformers": + + formatted_images = convert_images_for_api(images, target_format='pil') if images is not None and len(images) > 0 else None + response = await _transformers_manager.send_transformers_request( + model_name=llm_model, + system_message=system_message, + user_message=user_message, + messages=messages, + max_new_tokens=max_tokens, + images=formatted_images, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_strings_list=stop, + repetition_penalty=repeat_penalty, + seed=seed, + keep_alive=keep_alive, + precision=precision, + attention=attention + ) + return response + else: + # For other providers, convert to base64 only if images exist + formatted_images = convert_images_for_api(images, target_format='base64') if images is not None and len(images) > 0 else None + formatted_mask = convert_images_for_api(mask, target_format='base64') if mask is not None and len(mask) > 0 else None + + api_functions = { + "groq": send_groq_request, + "anthropic": send_anthropic_request, + "openai": send_openai_request, + "xai": send_xai_request, + "kobold": send_kobold_request, + "ollama": send_ollama_request, + "lmstudio": send_lmstudio_request, + "textgen": send_textgen_request, + "llamacpp": send_llama_cpp_request, + "mistral": send_mistral_request, + "vllm": send_vllm_request, + "gemini": send_gemini_request, + "transformers": None, # Handled separately + } + + if llm_provider not in api_functions and llm_provider != "transformers": + raise ValueError(f"Invalid llm_provider: {llm_provider}") + + if llm_provider == "transformers": + # This should be handled above, but included for safety + raise ValueError("Transformers provider should be handled separately.") + else: + # Existing logic for other providers + api_function = api_functions[llm_provider] + # Prepare API-specific keyword arguments + kwargs = {} + + if llm_provider == "ollama": + api_url = f"http://{base_ip}:{port}/api/chat" + kwargs = dict( + api_url=api_url, + base64_images=formatted_images, + model=llm_model, + system_message=system_message, + user_message=user_message, + messages=messages, + seed=seed, + temperature=temperature, + max_tokens=max_tokens, + random=random, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + stop=stop, + keep_alive=keep_alive, + tools=tools, + tool_choice=tool_choice, + ) + elif llm_provider in ["kobold", "lmstudio", "textgen", "llamacpp", "vllm"]: + api_url = f"http://{base_ip}:{port}/v1/chat/completions" + kwargs = { + "api_url": api_url, + "base64_images": formatted_images, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "seed": seed, + "temperature": temperature, + "max_tokens": max_tokens, + "top_k": top_k, + "top_p": top_p, + "repeat_penalty": repeat_penalty, + "stop": stop, + "tools": tools, + "tool_choice": tool_choice, + } + if llm_provider == "llamacpp": + kwargs.pop("tool_choice", None) + elif llm_provider == "vllm": + kwargs["api_key"] = llm_api_key + elif llm_provider == "gemini": + kwargs = { + "base64_images": formatted_images, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_k": top_k, + "top_p": top_p, + "stop": stop, + "api_key": llm_api_key, + "tools": tools, + "tool_choice": tool_choice, + } + elif llm_provider == "openai": + if llm_model.startswith("dall-e"): + if strategy == "create": + # Generate image + generated_images = await generate_image( + prompt=user_message, + model=llm_model, + n=batch_count, + size=size, + api_key=llm_api_key + ) + return {"images": generated_images} + elif strategy == "edit": + + # Edit image + edited_images = await edit_image( + image_base64=formatted_images[0], + mask_base64=formatted_mask, + prompt=user_message, + model=llm_model, + n=batch_count, + size=size, + api_key=llm_api_key + ) + return {"images": edited_images} + elif strategy == "variations": + # Generate variations + variations_images = await generate_image_variations( + image_base64=formatted_images[0], + model=llm_model, + n=batch_count, + size=size, + api_key=llm_api_key + ) + return {"images": variations_images} + else: + api_url = f"https://api.openai.com/v1/chat/completions" + kwargs = { + "api_url": api_url, + "base64_images": formatted_images, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "api_key": llm_api_key, + "seed": seed if random else None, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "repeat_penalty": repeat_penalty, + "tools": tools, + "tool_choice": tool_choice, + } + elif llm_provider == "xai": + api_url = f"https://api.x.ai/v1/chat/completions" + kwargs = { + "api_url": api_url, + "base64_images": formatted_images, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "api_key": llm_api_key, + "seed": seed if random else None, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "repeat_penalty": repeat_penalty, + "tools": tools, + "tool_choice": tool_choice, + } + elif llm_provider == "anthropic": + kwargs = { + "api_key": llm_api_key, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "base64_images": formatted_images, + "tools": tools, + "tool_choice": tool_choice + } + elif llm_provider == "groq": + kwargs = { + "base64_images": formatted_images, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "api_key": llm_api_key, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "tools": tools, + "tool_choice": tool_choice, + } + elif llm_provider == "mistral": + kwargs = { + "base64_images": formatted_images, + "model": llm_model, + "system_message": system_message, + "user_message": user_message, + "messages": messages, + "api_key": llm_api_key, + "seed": seed if random else None, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "tools": tools, + "tool_choice": tool_choice, + } + else: + raise ValueError(f"Unsupported llm_provider: {llm_provider}") + + response = await api_function(**kwargs) + + # Ensure response is properly awaited if it's a coroutine + if asyncio.iscoroutine(response): + response = await response + + if isinstance(response, dict): + choices = response.get("choices", []) + if choices and "content" in choices[0].get("message", {}): + content = choices[0]["message"]["content"] + if content.startswith("Error:"): + print(f"Error from {llm_provider} API: {content}") + if tools: + return response + else: + try: + return response["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as e: + error_msg = f"Error formatting response: {str(e)}" + logger.error(error_msg) + return {"choices": [{"message": {"content": error_msg}}]} + + except Exception as e: + logger.error(f"Exception in send_request: {str(e)}", exc_info=True) + return {"choices": [{"message": {"content": f"Exception: {str(e)}"}}]} + +def format_response(response, tools): + """Helper function to format the response consistently""" + if tools: + return response + try: + if isinstance(response, dict) and "choices" in response: + return response["choices"][0]["message"]["content"] + return response + except (KeyError, IndexError, TypeError) as e: + error_msg = f"Error formatting response: {str(e)}" + logger.error(error_msg) + return {"choices": [{"message": {"content": error_msg}}]} + +async def create_embedding(embedding_provider: str, api_base: str, embedding_model: str, input: Union[str, List[str]], embedding_api_key: Optional[str] = None) -> Union[List[float], None]: # Correct return type hint + if embedding_provider == "ollama": + return await create_ollama_embedding(api_base, embedding_model, input) + + + elif embedding_provider in ["openai", "lmstudio", "llamacpp", "textgen", "mistral", "xai"]: + try: + return await create_openai_compatible_embedding(api_base, embedding_model, input, embedding_api_key) # Try block for more precise error handling + except ValueError as e: + print(f"Error creating embedding: {e}") + return None # Return None on error + + else: + raise ValueError(f"Unsupported embedding_provider: {embedding_provider}") diff --git a/superflorence.py b/superflorence.py index a6fc5e3..f6da650 100644 --- a/superflorence.py +++ b/superflorence.py @@ -1,467 +1,467 @@ -import torch -import torchvision.transforms.functional as F -import numpy as np -import logging -from PIL import Image, ImageDraw, ImageFont, ImageColor -import supervision as sv -from io import BytesIO -import base64 -import json -import random -import os -import re -import comfy.model_management as mm -from .transformers_api import TransformersModelManager -from torchvision.transforms import functional as TF -from supervision.detection.lmm import from_florence_2 -from json import JSONEncoder -from typing import Tuple, Optional, List, Union -import folder_paths - -logger = logging.getLogger(__name__) - -class NumpyEncoder(JSONEncoder): - """Custom JSON Encoder that handles NumPy arrays and torch tensors.""" - def default(self, obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - if isinstance(obj, np.integer): - return int(obj) - if isinstance(obj, np.floating): - return float(obj) - if isinstance(obj, torch.Tensor): - return obj.cpu().numpy().tolist() - return super().default(obj) - -SUPPORTED_TASKS_FLORENCE_2 = [ - "", - "", - "", - "", - "", - "", - "", - "", - "", - "" -] - -def process_mask(mask, image_size=None): - """Process mask to ensure compatibility with ComfyUI.""" - if mask is None: - return None - - # Convert to numpy if tensor - if isinstance(mask, torch.Tensor): - mask = mask.cpu().numpy() - - # Handle boolean masks - if mask.dtype == bool: - mask = mask.astype(np.float32) - - # Ensure float32 - mask = mask.astype(np.float32) - - # Convert to tensor - mask = torch.from_numpy(mask) - - # Handle different shapes - if len(mask.shape) == 2: # Single mask - mask = mask.unsqueeze(0) - elif len(mask.shape) == 3: # Multiple masks - if mask.shape[0] > 1: # Multiple masks to combine - if image_size is not None: # Resize individual masks if needed - W, H = image_size - resized_masks = [] - for m in mask: - if m.shape != (H, W): - m = F.interpolate( - m.unsqueeze(0).unsqueeze(0), - size=(H, W), - mode='nearest' - ).squeeze() - resized_masks.append(m) - mask = torch.stack(resized_masks) - # Combine masks if multiple - mask = mask.any(dim=0).float().unsqueeze(0) - - # Final resize if needed - if image_size is not None: - W, H = image_size - if mask.shape[-2:] != (H, W): - mask = F.interpolate( - mask.unsqueeze(0), - size=(H, W), - mode='nearest' - ).squeeze(0) - - return mask - -def process_mask_selection(masks, selection, labels=None): - """Process mask selection based on indices or labels.""" - if not selection or masks is None: - return masks - - selections = selection.split(',') - mask_indices = [] - - for sel in selections: - sel = sel.strip() - if sel.isdigit(): - idx = int(sel) - if 0 <= idx < len(masks): - mask_indices.append(idx) - elif labels is not None: - for i, label in enumerate(labels): - if sel.lower() in label.lower(): - mask_indices.append(i) - - if not mask_indices: - return masks - - selected_masks = masks[mask_indices] - return selected_masks.any(dim=0).float().unsqueeze(0) - -def generate_mask_from_box(box: np.ndarray, image_size: Tuple[int, int]) -> np.ndarray: - """ - Generate a binary mask from a bounding box. - - Args: - box (np.ndarray): Array of [x1, y1, x2, y2] coordinates - image_size (Tuple[int, int]): (width, height) of the image - - Returns: - np.ndarray: Binary mask array of shape (H, W) - """ - W, H = image_size - mask = np.zeros((H, W), dtype=np.bool_) - x1, y1, x2, y2 = map(lambda x: max(0, int(x)), box) # Ensure non-negative integers - x2 = min(x2, W) # Ensure within image bounds - y2 = min(y2, H) - if x2 > x1 and y2 > y1: # Only set mask if box is valid - mask[y1:y2, x1:x2] = True - return mask - - - -class FlorenceModule: - def __init__(self): - self.model_manager = TransformersModelManager() - self.device = self.model_manager.device - self.offload_device = self.model_manager.offload_device - self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI_IF_AI_tools", "IF_AI", "placeholder.png") - - self.box_annotator = sv.BoxAnnotator( - color=sv.ColorPalette.DEFAULT, - thickness=2, - color_lookup=sv.ColorLookup.CLASS - ) - self.label_annotator = sv.LabelAnnotator( - color=sv.ColorPalette.DEFAULT, - text_color=sv.Color.WHITE, - text_scale=0.5, - text_thickness=1, - text_padding=10, - text_position=sv.Position.TOP_LEFT, - color_lookup=sv.ColorLookup.CLASS - ) - self.mask_annotator = sv.MaskAnnotator( - color=sv.ColorPalette.DEFAULT, - opacity=0.5, - color_lookup=sv.ColorLookup.CLASS - ) - - self.colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', - 'olive', 'cyan', 'red', 'lime', 'indigo', 'violet', - 'aqua', 'magenta', 'gold', 'tan', 'skyblue'] - - def prepare_json_output(self, out_data): - """Prepare output data for JSON serialization.""" - try: - return json.dumps(out_data, cls=NumpyEncoder, indent=2) - except Exception as e: - logger.error(f"Error serializing output data: {e}") - return json.dumps({"error": str(e)}) - - def format_output_data(self, detections, labels, mask, W, H, task): - """Format detection data for output.""" - try: - output = { - "boxes": detections.xyxy.tolist() if detections.xyxy is not None else [], - "task": task, - "dimensions": {"width": W, "height": H} - } - - if labels is not None: - if isinstance(labels, (np.ndarray, torch.Tensor)): - output["labels"] = labels.tolist() - else: - output["labels"] = list(labels) - else: - output["labels"] = [] - - if mask is not None: - if isinstance(mask, (np.ndarray, torch.Tensor)): - output["has_mask"] = True - output["mask_shape"] = list(mask.shape) - else: - output["has_mask"] = False - else: - output["has_mask"] = False - - return output - except Exception as e: - logger.error(f"Error formatting output data: {e}") - return {"error": str(e)} - - def parse_florence_response(self, text): - """Parse Florence response to extract labels and locations.""" - pattern = r'([^<]+)(?:)' - matches = re.finditer(pattern, text) - - labels = [] - locations = [] - - for match in matches: - label = match.group(1).strip() - coords = [int(match.group(i)) for i in range(2, 6)] - labels.append(label) - locations.append(coords) - - return labels, np.array(locations) if locations else np.array([]) - - def validate_task(self, task_prompt: str) -> str: - """Validate and format task prompt.""" - task_key = f"<{task_prompt.upper()}>" if not task_prompt.startswith("<") else task_prompt - if task_key not in SUPPORTED_TASKS_FLORENCE_2: - raise ValueError(f"Task {task_key} not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}") - return task_key - - def handle_task_specific_processing(self, task: str, response, W: int, H: int): - """Handle task-specific processing for Florence output.""" - xyxy, labels, mask, xyxyxyxy = from_florence_2(response, (W, H)) - - # Task-specific processing - if task in ["", ""]: - # These tasks return masks directly - if mask is None: - logger.warning(f"No mask returned for segmentation task {task}") - return xyxy, labels, mask, xyxyxyxy - - elif task in ["", "", ""]: - # These tasks return boxes and labels - if xyxy is None or len(xyxy) == 0: - logger.warning(f"No boxes returned for detection task {task}") - return xyxy, labels, mask, xyxyxyxy - - # Generate masks from boxes - if mask is None: - try: - masks_list = [] - for box in xyxy: - box_mask = generate_mask_from_box(box, (W, H)) - if box_mask is not None: - masks_list.append(box_mask) - mask = np.stack(masks_list) if masks_list else None - logger.debug(f"Generated {len(masks_list)} masks from boxes for {task}") - except Exception as e: - logger.error(f"Error generating masks from boxes: {e}") - mask = None - - elif task == "": - # Handle OCR with special oriented boxes - if xyxyxyxy is not None: - logger.debug(f"Processing OCR with oriented boxes: {len(xyxyxyxy)} regions") - - elif task in ["", ""]: - # These tasks return a single region with description - if labels is not None and len(labels) > 0: - logger.debug(f"Region description: {labels[0]}") - - return xyxy, labels, mask, xyxyxyxy - - async def run_florence(self, images, task, task_prompt, llm_model, precision, attention, - fill_mask, output_mask_select, keep_alive, max_new_tokens, - temperature, top_p, top_k, repetition_penalty, seed, text_input): - try: - # Validate task and format prompt - task_key = self.validate_task(task_prompt) - prompt = f"{task_key} {text_input}" if text_input else task_key - logger.debug(f"Using task: {task_key} with prompt: {prompt}") - - if len(images.shape) == 3: - images = images.unsqueeze(0) - images = images.permute(0, 3, 1, 2) - - out = [] - out_masks = [] - out_results = [] - out_data = [] - - for img in images: - try: - image_pil = TF.to_pil_image(img) - W, H = image_pil.size - - result = await self.model_manager.send_transformers_request( - model_name=llm_model, - system_message="", - user_message=prompt, - messages=[], - max_new_tokens=max_new_tokens, - images=[image_pil], - temperature=temperature, - top_p=top_p, - top_k=top_k, - stop_strings_list=["<|endoftext|>"], - repetition_penalty=repetition_penalty, - seed=seed, - keep_alive=keep_alive, - precision=precision, - attention=attention - ) - - generated_text = result[0] - response = result[1][0] if isinstance(result[1], list) else result[1] - - # Process Florence output with task-specific handling - xyxy, labels, mask, xyxyxyxy = self.handle_task_specific_processing( - task_key, response, W, H - ) - - # Generate masks for bounding boxes if no mask was provided - if mask is None and xyxy is not None and len(xyxy) > 0: - try: - masks_list = [] - for box in xyxy: - box_mask = generate_mask_from_box(box, (W, H)) - if box_mask is not None: - masks_list.append(box_mask) - mask = np.stack(masks_list) if masks_list else None - logger.debug(f"Generated {len(masks_list)} masks from boxes") - except Exception as e: - logger.error(f"Error generating masks from boxes: {e}") - mask = None - - # Create detections object - detections = sv.Detections( - xyxy=xyxy, - mask=mask, - class_id=np.arange(len(labels)) if labels is not None else None, - data={"class_name": labels} if labels is not None else None - ) - - # Create annotated image - annotated_frame = np.array(image_pil) - - # Process and apply masks - if mask is not None: - # Handle mask selection - if output_mask_select: - selected_indices = [] - selections = output_mask_select.split(',') - - for sel in selections: - sel = sel.strip().lower() - # Check for label match - if labels is not None: - for idx, label in enumerate(labels): - if sel in label.lower(): - selected_indices.append(idx) - # Check for numeric index - elif sel.isdigit(): - idx = int(sel) - if 0 <= idx < len(mask): - selected_indices.append(idx) - - if selected_indices: - selected_mask = np.zeros_like(mask[0]) - for idx in selected_indices: - selected_mask = np.logical_or(selected_mask, mask[idx]) - mask = np.array([selected_mask]) - - if fill_mask: - detections.mask = mask - annotated_frame = self.mask_annotator.annotate( - scene=annotated_frame, - detections=detections - ) - - # Convert mask for output - processed_mask = process_mask(mask, (W, H)) - if processed_mask is not None: - out_masks.append(processed_mask) - - # Draw boxes and labels - annotated_frame = self.box_annotator.annotate( - scene=annotated_frame, - detections=detections - ) - - if labels is not None: - formatted_labels = [] - for idx, label in enumerate(labels): - if output_mask_select: - if str(idx) in output_mask_select.split(",") or \ - any(sel.lower() in label.lower() for sel in output_mask_select.split(",")): - formatted_labels.append(f"[{idx}] {label}") - else: - formatted_labels.append(label) - else: - formatted_labels.append(f"{label}") - - annotated_frame = self.label_annotator.annotate( - scene=annotated_frame, - detections=detections, - labels=formatted_labels - ) - - # Convert to tensor - annotated_frame = Image.fromarray(annotated_frame) - out_tensor = TF.to_tensor(annotated_frame).unsqueeze(0).permute(0, 2, 3, 1).cpu().float() - out.append(out_tensor) - - # Store results - out_results.append(generated_text) - out_data.append(self.format_output_data(detections, labels, mask, W, H, task)) - - except Exception as e: - logger.error(f"Error processing image: {str(e)}") - continue - - # Combine outputs - if len(out) > 0: - out_tensor = torch.cat(out, dim=0) - if len(out_masks) > 0: - masks_tensor = torch.cat(out_masks, dim=0) - else: - masks_tensor = torch.zeros((1, out_tensor.shape[1], out_tensor.shape[2]), dtype=torch.float32) - else: - out_tensor = torch.zeros((1, 64, 64, 3), dtype=torch.float32) - masks_tensor = torch.zeros((1, 64, 64), dtype=torch.float32) - - if not keep_alive: - self.model_manager.unload_model(llm_model) - - return { - "Question": text_input, - "Response": out_results[0] if len(out_results) == 1 else out_results, - "Negative": "", - "Tool_Output": self.prepare_json_output(out_data), - "Retrieved_Image": out_tensor, - "Mask": masks_tensor - } - - except Exception as e: - logger.error(f"Error in run_florence: {str(e)}", exc_info=True) - # Return valid tensors even in case of error - error_data = {"error": str(e)} - return { - "Question": text_input, - "Response": f"Error: {str(e)}", - "Negative": "", - "Tool_Output": json.dumps(error_data), - "Retrieved_Image": torch.zeros((1, 64, 64, 3), dtype=torch.float32), - "Mask": torch.zeros((1, 64, 64), dtype=torch.float32) +import torch +import torchvision.transforms.functional as F +import numpy as np +import logging +from PIL import Image, ImageDraw, ImageFont, ImageColor +import supervision as sv +from io import BytesIO +import base64 +import json +import random +import os +import re +import comfy.model_management as mm +from .transformers_api import TransformersModelManager +from torchvision.transforms import functional as TF +from supervision.detection.lmm import from_florence_2 +from json import JSONEncoder +from typing import Tuple, Optional, List, Union +import folder_paths + +logger = logging.getLogger(__name__) + +class NumpyEncoder(JSONEncoder): + """Custom JSON Encoder that handles NumPy arrays and torch tensors.""" + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + return super().default(obj) + +SUPPORTED_TASKS_FLORENCE_2 = [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" +] + +def process_mask(mask, image_size=None): + """Process mask to ensure compatibility with ComfyUI.""" + if mask is None: + return None + + # Convert to numpy if tensor + if isinstance(mask, torch.Tensor): + mask = mask.cpu().numpy() + + # Handle boolean masks + if mask.dtype == bool: + mask = mask.astype(np.float32) + + # Ensure float32 + mask = mask.astype(np.float32) + + # Convert to tensor + mask = torch.from_numpy(mask) + + # Handle different shapes + if len(mask.shape) == 2: # Single mask + mask = mask.unsqueeze(0) + elif len(mask.shape) == 3: # Multiple masks + if mask.shape[0] > 1: # Multiple masks to combine + if image_size is not None: # Resize individual masks if needed + W, H = image_size + resized_masks = [] + for m in mask: + if m.shape != (H, W): + m = F.interpolate( + m.unsqueeze(0).unsqueeze(0), + size=(H, W), + mode='nearest' + ).squeeze() + resized_masks.append(m) + mask = torch.stack(resized_masks) + # Combine masks if multiple + mask = mask.any(dim=0).float().unsqueeze(0) + + # Final resize if needed + if image_size is not None: + W, H = image_size + if mask.shape[-2:] != (H, W): + mask = F.interpolate( + mask.unsqueeze(0), + size=(H, W), + mode='nearest' + ).squeeze(0) + + return mask + +def process_mask_selection(masks, selection, labels=None): + """Process mask selection based on indices or labels.""" + if not selection or masks is None: + return masks + + selections = selection.split(',') + mask_indices = [] + + for sel in selections: + sel = sel.strip() + if sel.isdigit(): + idx = int(sel) + if 0 <= idx < len(masks): + mask_indices.append(idx) + elif labels is not None: + for i, label in enumerate(labels): + if sel.lower() in label.lower(): + mask_indices.append(i) + + if not mask_indices: + return masks + + selected_masks = masks[mask_indices] + return selected_masks.any(dim=0).float().unsqueeze(0) + +def generate_mask_from_box(box: np.ndarray, image_size: Tuple[int, int]) -> np.ndarray: + """ + Generate a binary mask from a bounding box. + + Args: + box (np.ndarray): Array of [x1, y1, x2, y2] coordinates + image_size (Tuple[int, int]): (width, height) of the image + + Returns: + np.ndarray: Binary mask array of shape (H, W) + """ + W, H = image_size + mask = np.zeros((H, W), dtype=np.bool_) + x1, y1, x2, y2 = map(lambda x: max(0, int(x)), box) # Ensure non-negative integers + x2 = min(x2, W) # Ensure within image bounds + y2 = min(y2, H) + if x2 > x1 and y2 > y1: # Only set mask if box is valid + mask[y1:y2, x1:x2] = True + return mask + + + +class FlorenceModule: + def __init__(self): + self.model_manager = TransformersModelManager() + self.device = self.model_manager.device + self.offload_device = self.model_manager.offload_device + self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "placeholder.png") + + self.box_annotator = sv.BoxAnnotator( + color=sv.ColorPalette.DEFAULT, + thickness=2, + color_lookup=sv.ColorLookup.CLASS + ) + self.label_annotator = sv.LabelAnnotator( + color=sv.ColorPalette.DEFAULT, + text_color=sv.Color.WHITE, + text_scale=0.5, + text_thickness=1, + text_padding=10, + text_position=sv.Position.TOP_LEFT, + color_lookup=sv.ColorLookup.CLASS + ) + self.mask_annotator = sv.MaskAnnotator( + color=sv.ColorPalette.DEFAULT, + opacity=0.5, + color_lookup=sv.ColorLookup.CLASS + ) + + self.colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', + 'olive', 'cyan', 'red', 'lime', 'indigo', 'violet', + 'aqua', 'magenta', 'gold', 'tan', 'skyblue'] + + def prepare_json_output(self, out_data): + """Prepare output data for JSON serialization.""" + try: + return json.dumps(out_data, cls=NumpyEncoder, indent=2) + except Exception as e: + logger.error(f"Error serializing output data: {e}") + return json.dumps({"error": str(e)}) + + def format_output_data(self, detections, labels, mask, W, H, task): + """Format detection data for output.""" + try: + output = { + "boxes": detections.xyxy.tolist() if detections.xyxy is not None else [], + "task": task, + "dimensions": {"width": W, "height": H} + } + + if labels is not None: + if isinstance(labels, (np.ndarray, torch.Tensor)): + output["labels"] = labels.tolist() + else: + output["labels"] = list(labels) + else: + output["labels"] = [] + + if mask is not None: + if isinstance(mask, (np.ndarray, torch.Tensor)): + output["has_mask"] = True + output["mask_shape"] = list(mask.shape) + else: + output["has_mask"] = False + else: + output["has_mask"] = False + + return output + except Exception as e: + logger.error(f"Error formatting output data: {e}") + return {"error": str(e)} + + def parse_florence_response(self, text): + """Parse Florence response to extract labels and locations.""" + pattern = r'([^<]+)(?:)' + matches = re.finditer(pattern, text) + + labels = [] + locations = [] + + for match in matches: + label = match.group(1).strip() + coords = [int(match.group(i)) for i in range(2, 6)] + labels.append(label) + locations.append(coords) + + return labels, np.array(locations) if locations else np.array([]) + + def validate_task(self, task_prompt: str) -> str: + """Validate and format task prompt.""" + task_key = f"<{task_prompt.upper()}>" if not task_prompt.startswith("<") else task_prompt + if task_key not in SUPPORTED_TASKS_FLORENCE_2: + raise ValueError(f"Task {task_key} not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}") + return task_key + + def handle_task_specific_processing(self, task: str, response, W: int, H: int): + """Handle task-specific processing for Florence output.""" + xyxy, labels, mask, xyxyxyxy = from_florence_2(response, (W, H)) + + # Task-specific processing + if task in ["", ""]: + # These tasks return masks directly + if mask is None: + logger.warning(f"No mask returned for segmentation task {task}") + return xyxy, labels, mask, xyxyxyxy + + elif task in ["", "", ""]: + # These tasks return boxes and labels + if xyxy is None or len(xyxy) == 0: + logger.warning(f"No boxes returned for detection task {task}") + return xyxy, labels, mask, xyxyxyxy + + # Generate masks from boxes + if mask is None: + try: + masks_list = [] + for box in xyxy: + box_mask = generate_mask_from_box(box, (W, H)) + if box_mask is not None: + masks_list.append(box_mask) + mask = np.stack(masks_list) if masks_list else None + logger.debug(f"Generated {len(masks_list)} masks from boxes for {task}") + except Exception as e: + logger.error(f"Error generating masks from boxes: {e}") + mask = None + + elif task == "": + # Handle OCR with special oriented boxes + if xyxyxyxy is not None: + logger.debug(f"Processing OCR with oriented boxes: {len(xyxyxyxy)} regions") + + elif task in ["", ""]: + # These tasks return a single region with description + if labels is not None and len(labels) > 0: + logger.debug(f"Region description: {labels[0]}") + + return xyxy, labels, mask, xyxyxyxy + + async def run_florence(self, images, task, task_prompt, llm_model, precision, attention, + fill_mask, output_mask_select, keep_alive, max_new_tokens, + temperature, top_p, top_k, repetition_penalty, seed, text_input): + try: + # Validate task and format prompt + task_key = self.validate_task(task_prompt) + prompt = f"{task_key} {text_input}" if text_input else task_key + logger.debug(f"Using task: {task_key} with prompt: {prompt}") + + if len(images.shape) == 3: + images = images.unsqueeze(0) + images = images.permute(0, 3, 1, 2) + + out = [] + out_masks = [] + out_results = [] + out_data = [] + + for img in images: + try: + image_pil = TF.to_pil_image(img) + W, H = image_pil.size + + result = await self.model_manager.send_transformers_request( + model_name=llm_model, + system_message="", + user_message=prompt, + messages=[], + max_new_tokens=max_new_tokens, + images=[image_pil], + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_strings_list=["<|endoftext|>"], + repetition_penalty=repetition_penalty, + seed=seed, + keep_alive=keep_alive, + precision=precision, + attention=attention + ) + + generated_text = result[0] + response = result[1][0] if isinstance(result[1], list) else result[1] + + # Process Florence output with task-specific handling + xyxy, labels, mask, xyxyxyxy = self.handle_task_specific_processing( + task_key, response, W, H + ) + + # Generate masks for bounding boxes if no mask was provided + if mask is None and xyxy is not None and len(xyxy) > 0: + try: + masks_list = [] + for box in xyxy: + box_mask = generate_mask_from_box(box, (W, H)) + if box_mask is not None: + masks_list.append(box_mask) + mask = np.stack(masks_list) if masks_list else None + logger.debug(f"Generated {len(masks_list)} masks from boxes") + except Exception as e: + logger.error(f"Error generating masks from boxes: {e}") + mask = None + + # Create detections object + detections = sv.Detections( + xyxy=xyxy, + mask=mask, + class_id=np.arange(len(labels)) if labels is not None else None, + data={"class_name": labels} if labels is not None else None + ) + + # Create annotated image + annotated_frame = np.array(image_pil) + + # Process and apply masks + if mask is not None: + # Handle mask selection + if output_mask_select: + selected_indices = [] + selections = output_mask_select.split(',') + + for sel in selections: + sel = sel.strip().lower() + # Check for label match + if labels is not None: + for idx, label in enumerate(labels): + if sel in label.lower(): + selected_indices.append(idx) + # Check for numeric index + elif sel.isdigit(): + idx = int(sel) + if 0 <= idx < len(mask): + selected_indices.append(idx) + + if selected_indices: + selected_mask = np.zeros_like(mask[0]) + for idx in selected_indices: + selected_mask = np.logical_or(selected_mask, mask[idx]) + mask = np.array([selected_mask]) + + if fill_mask: + detections.mask = mask + annotated_frame = self.mask_annotator.annotate( + scene=annotated_frame, + detections=detections + ) + + # Convert mask for output + processed_mask = process_mask(mask, (W, H)) + if processed_mask is not None: + out_masks.append(processed_mask) + + # Draw boxes and labels + annotated_frame = self.box_annotator.annotate( + scene=annotated_frame, + detections=detections + ) + + if labels is not None: + formatted_labels = [] + for idx, label in enumerate(labels): + if output_mask_select: + if str(idx) in output_mask_select.split(",") or \ + any(sel.lower() in label.lower() for sel in output_mask_select.split(",")): + formatted_labels.append(f"[{idx}] {label}") + else: + formatted_labels.append(label) + else: + formatted_labels.append(f"{label}") + + annotated_frame = self.label_annotator.annotate( + scene=annotated_frame, + detections=detections, + labels=formatted_labels + ) + + # Convert to tensor + annotated_frame = Image.fromarray(annotated_frame) + out_tensor = TF.to_tensor(annotated_frame).unsqueeze(0).permute(0, 2, 3, 1).cpu().float() + out.append(out_tensor) + + # Store results + out_results.append(generated_text) + out_data.append(self.format_output_data(detections, labels, mask, W, H, task)) + + except Exception as e: + logger.error(f"Error processing image: {str(e)}") + continue + + # Combine outputs + if len(out) > 0: + out_tensor = torch.cat(out, dim=0) + if len(out_masks) > 0: + masks_tensor = torch.cat(out_masks, dim=0) + else: + masks_tensor = torch.zeros((1, out_tensor.shape[1], out_tensor.shape[2]), dtype=torch.float32) + else: + out_tensor = torch.zeros((1, 64, 64, 3), dtype=torch.float32) + masks_tensor = torch.zeros((1, 64, 64), dtype=torch.float32) + + if not keep_alive: + self.model_manager.unload_model(llm_model) + + return { + "Question": text_input, + "Response": out_results[0] if len(out_results) == 1 else out_results, + "Negative": "", + "Tool_Output": self.prepare_json_output(out_data), + "Retrieved_Image": out_tensor, + "Mask": masks_tensor + } + + except Exception as e: + logger.error(f"Error in run_florence: {str(e)}", exc_info=True) + # Return valid tensors even in case of error + error_data = {"error": str(e)} + return { + "Question": text_input, + "Response": f"Error: {str(e)}", + "Negative": "", + "Tool_Output": json.dumps(error_data), + "Retrieved_Image": torch.zeros((1, 64, 64, 3), dtype=torch.float32), + "Mask": torch.zeros((1, 64, 64), dtype=torch.float32) } \ No newline at end of file diff --git a/textgen_api.py b/textgen_api.py index f60b2f1..1efb933 100644 --- a/textgen_api.py +++ b/textgen_api.py @@ -1,173 +1,173 @@ -#textgen_api.py -import requests -import json -from typing import List, Union, Optional -import aiohttp -import asyncio -import logging -logger = logging.getLogger(__name__) - - -def create_openai_compatible_embedding(api_base: str, model: str, input: Union[str, List[str]], api_key: Optional[str] = None) -> List[float]: - """ - Create embeddings using an OpenAI-compatible API. - - :param api_base: The base URL for the API - :param model: The name of the model to use for embeddings - :param input: A string or list of strings to embed - :param api_key: The API key (if required) - :return: A list of embeddings - """ - # Normalize the API base URL - api_base = api_base.rstrip('/') - if not api_base.endswith('/v1'): - api_base += '/v1' - - url = f"{api_base}/embeddings" - - headers = { - "Content-Type": "application/json" - } - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - payload = { - "model": model, - "input": input, - "encoding_format": "float" - } - - try: - response = requests.post(url, headers=headers, json=payload) - response.raise_for_status() - result = response.json() - - if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]: - # If multiple embeddings are returned, we'll just use the first one - return result["data"][0]["embedding"] - else: - raise ValueError("Unexpected response format: 'embedding' data not found") - except requests.RequestException as e: - raise RuntimeError(f"Error calling embedding API: {str(e)}") - -async def send_textgen_request(api_url, base64_images, model, system_message, user_message, messages, seed, temperature, - max_tokens, top_k, top_p, repeat_penalty, stop, tools=None, tool_choice=None): - headers = { - "Content-Type": "application/json" - } - - data = { - "model": model, - "messages": prepare_textgen_messages(system_message, user_message, messages, base64_images), - "temperature": temperature, - "max_tokens": max_tokens, - "presence_penalty": repeat_penalty, - "top_p": top_p, - "top_k": top_k, - "seed": seed - } - - if stop: - data["stop"] = stop - if tools: - data["functions"] = tools - if tool_choice: - data["function_call"] = tool_choice - - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - response.raise_for_status() - response_data = await response.json() - - choices = response_data.get('choices', []) - if choices: - choice = choices[0] - message = choice.get('message', {}) - if "function_call" in message: - return { - "choices": [{ - "message": { - "function_call": { - "name": message["function_call"]["name"], - "arguments": message["function_call"]["arguments"] - } - } - }] - } - else: - generated_text = message.get('content', '') - return { - "choices": [{ - "message": { - "content": generated_text - } - }] - } - else: - error_msg = "Error: No valid choices in the textgen response." - print(error_msg) - return {"choices": [{"message": {"content": error_msg}}]} - except aiohttp.ClientError as e: - error_msg = f"Error in textgen API request: {e}" - print(error_msg) - return {"choices": [{"message": {"content": error_msg}}]} - -def prepare_textgen_messages(system_message, user_message, messages, base64_image=None): - textgen_messages = [] - - if system_message: - textgen_messages.append({"role": "system", "content": system_message}) - - for message in messages: - role = message["role"] - content = message["content"] - - if isinstance(content, list): - # Handle multi-modal content - message_content = [] - for item in content: - if item["type"] == "text": - message_content.append({"type": "text", "text": item["text"]}) - elif item["type"] == "image_url": - message_content.append({ - "type": "image_url", - "image_url": {"url": item["image_url"]["url"]} - }) - textgen_messages.append({"role": role, "content": message_content}) - else: - textgen_messages.append({"role": role, "content": content}) - - # Add the current user message with image if provided - if base64_image: - textgen_messages.append({ - "role": "user", - "content": [ - {"type": "text", "text": user_message}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - } - ] - }) - else: - textgen_messages.append({"role": "user", "content": user_message}) - - return textgen_messages - -def parse_function_call(response, tools): - try: - # Look for JSON-like structure in the response - start = response.find("{") - end = response.rfind("}") + 1 - if start != -1 and end != -1: - json_str = response[start:end] - parsed = json.loads(json_str) - if "function_call" in parsed: - return parsed - except json.JSONDecodeError: - pass - +#textgen_api.py +import requests +import json +from typing import List, Union, Optional +import aiohttp +import asyncio +import logging +logger = logging.getLogger(__name__) + + +def create_openai_compatible_embedding(api_base: str, model: str, input: Union[str, List[str]], api_key: Optional[str] = None) -> List[float]: + """ + Create embeddings using an OpenAI-compatible API. + + :param api_base: The base URL for the API + :param model: The name of the model to use for embeddings + :param input: A string or list of strings to embed + :param api_key: The API key (if required) + :return: A list of embeddings + """ + # Normalize the API base URL + api_base = api_base.rstrip('/') + if not api_base.endswith('/v1'): + api_base += '/v1' + + url = f"{api_base}/embeddings" + + headers = { + "Content-Type": "application/json" + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + payload = { + "model": model, + "input": input, + "encoding_format": "float" + } + + try: + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + result = response.json() + + if "data" in result and len(result["data"]) > 0 and "embedding" in result["data"][0]: + # If multiple embeddings are returned, we'll just use the first one + return result["data"][0]["embedding"] + else: + raise ValueError("Unexpected response format: 'embedding' data not found") + except requests.RequestException as e: + raise RuntimeError(f"Error calling embedding API: {str(e)}") + +async def send_textgen_request(api_url, base64_images, model, system_message, user_message, messages, seed, temperature, + max_tokens, top_k, top_p, repeat_penalty, stop, tools=None, tool_choice=None): + headers = { + "Content-Type": "application/json" + } + + data = { + "model": model, + "messages": prepare_textgen_messages(system_message, user_message, messages, base64_images), + "temperature": temperature, + "max_tokens": max_tokens, + "presence_penalty": repeat_penalty, + "top_p": top_p, + "top_k": top_k, + "seed": seed + } + + if stop: + data["stop"] = stop + if tools: + data["functions"] = tools + if tool_choice: + data["function_call"] = tool_choice + + try: + async with aiohttp.ClientSession() as session: + async with session.post(api_url, headers=headers, json=data) as response: + response.raise_for_status() + response_data = await response.json() + + choices = response_data.get('choices', []) + if choices: + choice = choices[0] + message = choice.get('message', {}) + if "function_call" in message: + return { + "choices": [{ + "message": { + "function_call": { + "name": message["function_call"]["name"], + "arguments": message["function_call"]["arguments"] + } + } + }] + } + else: + generated_text = message.get('content', '') + return { + "choices": [{ + "message": { + "content": generated_text + } + }] + } + else: + error_msg = "Error: No valid choices in the textgen response." + print(error_msg) + return {"choices": [{"message": {"content": error_msg}}]} + except aiohttp.ClientError as e: + error_msg = f"Error in textgen API request: {e}" + print(error_msg) + return {"choices": [{"message": {"content": error_msg}}]} + +def prepare_textgen_messages(system_message, user_message, messages, base64_image=None): + textgen_messages = [] + + if system_message: + textgen_messages.append({"role": "system", "content": system_message}) + + for message in messages: + role = message["role"] + content = message["content"] + + if isinstance(content, list): + # Handle multi-modal content + message_content = [] + for item in content: + if item["type"] == "text": + message_content.append({"type": "text", "text": item["text"]}) + elif item["type"] == "image_url": + message_content.append({ + "type": "image_url", + "image_url": {"url": item["image_url"]["url"]} + }) + textgen_messages.append({"role": role, "content": message_content}) + else: + textgen_messages.append({"role": role, "content": content}) + + # Add the current user message with image if provided + if base64_image: + textgen_messages.append({ + "role": "user", + "content": [ + {"type": "text", "text": user_message}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + } + ] + }) + else: + textgen_messages.append({"role": "user", "content": user_message}) + + return textgen_messages + +def parse_function_call(response, tools): + try: + # Look for JSON-like structure in the response + start = response.find("{") + end = response.rfind("}") + 1 + if start != -1 and end != -1: + json_str = response[start:end] + parsed = json.loads(json_str) + if "function_call" in parsed: + return parsed + except json.JSONDecodeError: + pass + return None \ No newline at end of file diff --git a/transformers_api.py b/transformers_api.py index f468c70..5e3ff74 100644 --- a/transformers_api.py +++ b/transformers_api.py @@ -1,348 +1,348 @@ -# transformers_api.py -from transformers import ( - Qwen2VLForConditionalGeneration, - Qwen2VLProcessor, - AutoConfig, - AutoModelForCausalLM, - AutoProcessor, - BitsAndBytesConfig, - GenerationConfig, - StopStringCriteria, - set_seed, -) -from typing import List, Union, Optional, Dict, Any -from PIL import Image -from io import BytesIO -import base64 -import torch -import logging -import os -import re -from folder_paths import models_dir -from unittest.mock import patch -from transformers.dynamic_module_utils import get_imports -import json -import importlib -import importlib.util -import comfy.model_management as mm -from torchvision.transforms import functional as TF - -logger = logging.getLogger(__name__) - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - - -class TransformersModelManager: - def __init__(self): - self.models_dir = models_dir - self.models = {} - self.processors = {} - self.loaded_models = {} - self.device = mm.get_torch_device() - self.offload_device = mm.unet_offload_device() - self.model_path = None - self.model_load_args = { - "device_map": self.device, - "torch_dtype": "auto", - "trust_remote_code": True - } - - def download_model_if_not_exists(self, model_name): - from huggingface_hub import snapshot_download - - model_dir = model_name.rsplit('/', 1)[-1] - model_path = os.path.join(self.models_dir, "LLM", model_dir) - if not os.path.exists(model_path): - logger.info(f"Downloading model '{model_name}' to: {model_path}") - try: - snapshot_download( - repo_id=model_name, - local_dir=model_path, - local_dir_use_symlinks=False, - token=os.getenv("HUGGINGFACE_TOKEN") or "" - ) - logger.info(f"Model '{model_name}' downloaded successfully.") - except Exception as e: - logger.error(f"An error occurred while downloading the model '{model_name}': {e}") - return None - else: - logger.info(f"Model '{model_name}' already exists at: {model_path}") - return model_path - - def hash_seed(self, seed): - import hashlib - seed_bytes = str(seed).encode('utf-8') - hash_object = hashlib.sha256(seed_bytes) - hashed_seed = int(hash_object.hexdigest(), 16) - return hashed_seed % (2**32) - - def load_model(self, model: str, precision: str, attention: str) -> Optional[Dict[str, Any]]: - if model in self.loaded_models: - logger.info(f"Model '{model}' already loaded and cached.") - return self.loaded_models[model] - - if precision == "int8": - quant_config = BitsAndBytesConfig(load_in_8bit=True) - dtype = torch.bfloat16 if 'mpt' in model.lower() or 'llama2' in model.lower() else torch.float16 - elif precision == "int4": - quant_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - bnb_4bit_compute_dtype=torch.bfloat16 - ) - dtype = torch.bfloat16 - else: - quant_config = None - dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}.get(precision, torch.float16) - - model_path = self.download_model_if_not_exists(model) - if model_path is None: - logger.error(f"Model path for '{model}' could not be determined.") - return None - - config_path = os.path.join(model_path, "config.json") - if not os.path.exists(config_path): - logger.error(f"Config file not found at: {config_path}") - return None - - device = self.device - try: - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - architectures = config.architectures - - if architectures and isinstance(architectures, list) and len(architectures) > 0: - model_class = architectures[0] - try: - common_args = { - "pretrained_model_name_or_path": model_path, - "attn_implementation": attention, - "torch_dtype": dtype, - "trust_remote_code": True, - "device_map": device, - } - - if quant_config: - common_args["quantization_config"] = quant_config - - if "florence" in model.lower() or 'florence' in model_path.lower() or "deepseek" in model.lower() or 'deepseek' in model_path.lower(): - with patch("transformers.dynamic_module_utils.get_imports", self.fixed_get_imports): - loaded_model = AutoModelForCausalLM.from_pretrained(**common_args) - elif "pixtral" in model.lower(): - from transformers import LlavaForConditionalGeneration - loaded_model = LlavaForConditionalGeneration.from_pretrained(**common_args, use_safetensors=True) - elif "molmo" in model.lower(): - loaded_model = AutoModelForCausalLM.from_pretrained(**common_args, use_safetensors=True) - elif "qwen2-vl" in model.lower(): - min_pixels = 224 * 224 - max_pixels = 1024 * 1024 - processor = Qwen2VLProcessor.from_pretrained( - model_path, - min_pixels=min_pixels, - max_pixels=max_pixels, - trust_remote_code=True - ) - loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(**common_args, use_safetensors=True) - else: - loaded_model = model_class.from_pretrained(**common_args) - - processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) - - except AttributeError: - logger.warning(f"AttributeError encountered. Forcing trust_remote_code=True for model: {model}") - loaded_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map=device) - processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) - - except Exception as e: - logger.error(f"Error loading model from config.json: {e}") - return None - - self.loaded_models[model] = {'model': loaded_model, 'processor': processor, 'dtype': dtype} - logger.info(f"Model '{model}' loaded successfully and cached.") - return self.loaded_models[model] - - async def send_transformers_request( - self, - model_name, - system_message, - user_message, - messages, - max_new_tokens, - images, - temperature, - top_p, - top_k, - stop_strings_list, - repetition_penalty, - seed, - keep_alive=True, - precision="fp16", - attention="sdpa", - ): - try: - if model_name in self.loaded_models: - logger.info(f"Model '{model_name}' already loaded and cached.") - model_data = self.loaded_models[model_name] - else: - model_data = self.load_model(model_name, precision=precision, attention=attention) - if model_data is None: - raise ValueError(f"Failed to load model '{model_name}'.") - - model = model_data['model'] - processor = model_data['processor'] - tokenizer = processor.tokenizer - dtype = model_data['dtype'] - - if seed is not None: - logger.info(f"Setting seed: {seed}") - set_seed(self.hash_seed(seed)) - - # Convert to PIL Images if necessary - pil_images = [] - if isinstance(images, torch.Tensor): - images = images.permute(0, 3, 1, 2) - for img in images: - pil_images.append(TF.to_pil_image(img)) - elif isinstance(images, list) and all(isinstance(img, Image.Image) for img in images): - pil_images = images - else: - raise ValueError("Images must be either a torch.Tensor or a list of PIL Images") - - logger.debug(f"Number of images processed: {len(pil_images)}") - - # Construct standardized messages - formatted_messages = self.construct_messages(system_message, user_message, messages, pil_images) - - logger.debug(f"Formatted messages: {formatted_messages}") - - if 'florence' in model_name.lower(): - # Process input for Florence models - generated_texts = [] - responses = [] - images_pil = [] - for pil_image in pil_images: - inputs = processor(images=[pil_image], text=user_message, return_tensors="pt", do_rescale=False).to(dtype).to(model.device) - - logger.debug(f"Inputs shape: {inputs['pixel_values'].shape}, dtype: {inputs['pixel_values'].dtype}") - logger.debug(f"Input IDs shape: {inputs['input_ids'].shape}, dtype: {inputs['input_ids'].dtype}") - - with torch.random.fork_rng(devices=[model.device]): - torch.random.manual_seed(seed) - - try: - generated_ids = model.generate( - **inputs, - max_new_tokens=max_new_tokens, - num_beams=3, - do_sample=True, - temperature=temperature, - top_p=top_p, - top_k=top_k, - repetition_penalty=repetition_penalty, - ) - except Exception as e: - logger.error(f"Error during model.generate: {e}") - raise - - results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] - generated_text = self.clean_results(results, user_message) - response = processor.post_process_generation(generated_text, task=user_message, image_size=pil_image.size) - generated_texts.append(generated_text) - responses.append(response) - # images_pil.append(pil_image) - - result = (generated_texts, responses) - else: - # Handle other transformers models - inputs = processor(formatted_messages, return_tensors="pt", padding=True).to(model.device) - - # Convert inputs to the correct dtype - inputs = {k: v.to(dtype=torch.long if v.dtype == torch.int64 else dtype) if torch.is_tensor(v) else v for k, v in inputs.items()} - - with torch.no_grad(): - try: - outputs = model.generate( - **inputs, - generation_config=GenerationConfig( - max_new_tokens=max_new_tokens, - do_sample=True, - temperature=temperature, - top_p=top_p, - top_k=top_k, - repetition_penalty=repetition_penalty, - ), - stopping_criteria=[StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings_list)], - ) - except Exception as e: - logger.error(f"Error during model.generate: {e}") - raise - - generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0] - - result = generated_text - - if not keep_alive: - self.unload_model(model_name) - - return result - - except Exception as e: - logger.error(f"Error in Transformers API request: {e}", exc_info=True) - return str(e) - - def clean_results(self, results, task): - if task == 'ocr_with_region': - clean_results = re.sub(r'|<[^>]*>', '\n', results) - clean_results = re.sub(r'\n+', '\n', clean_results) - else: - clean_results = results.replace('', '').replace('', '') - return clean_results - - def construct_messages(self, system_message, user_message, messages, pil_images): - """Constructs a standardized message format for transformer models.""" - formatted_messages = [] - if system_message: - formatted_messages.append({"role": "system", "content": system_message}) - - for msg in messages: - formatted_messages.append({"role": msg['role'], "content": msg['content']}) - - if user_message: - formatted_messages.append({ - "role": "user", - "content": [ - {"type": "text", "text": user_message}, - *[{"type": "image", "image": img} for img in pil_images] - ] - }) - - return formatted_messages - - def unload_model(self, model_name: str): - print(f"Offloading model: {model_name}") - if model_name in self.loaded_models: - model = self.loaded_models[model_name]['model'] - model.to(self.offload_device) - del self.loaded_models[model_name] - mm.soft_empty_cache() - else: - print(f"Model {model_name} not found in loaded models.") - - @classmethod - def fixed_get_imports(cls, filename: Union[str, os.PathLike], *args, **kwargs) -> List[str]: - """Remove 'flash_attn' from imports if present.""" - try: - if not str(filename).endswith("modeling_florence2.py") or not str(filename).endswith("modeling_deepseek.py"): - return get_imports(filename) - imports = get_imports(filename) - if "flash_attn" in imports: - imports.remove("flash_attn") - return imports - except Exception as e: - print(f"No flash_attn import to remove: {e}") - return get_imports(filename) - - -# Initialize a global manager instance -_transformers_manager = TransformersModelManager() +# transformers_api.py +from transformers import ( + Qwen2VLForConditionalGeneration, + Qwen2VLProcessor, + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + BitsAndBytesConfig, + GenerationConfig, + StopStringCriteria, + set_seed, +) +from typing import List, Union, Optional, Dict, Any +from PIL import Image +from io import BytesIO +import base64 +import torch +import logging +import os +import re +from folder_paths import models_dir +from unittest.mock import patch +from transformers.dynamic_module_utils import get_imports +import json +import importlib +import importlib.util +import comfy.model_management as mm +from torchvision.transforms import functional as TF + +logger = logging.getLogger(__name__) + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class TransformersModelManager: + def __init__(self): + self.models_dir = models_dir + self.models = {} + self.processors = {} + self.loaded_models = {} + self.device = mm.get_torch_device() + self.offload_device = mm.unet_offload_device() + self.model_path = None + self.model_load_args = { + "device_map": self.device, + "torch_dtype": "auto", + "trust_remote_code": True + } + + def download_model_if_not_exists(self, model_name): + from huggingface_hub import snapshot_download + + model_dir = model_name.rsplit('/', 1)[-1] + model_path = os.path.join(self.models_dir, "LLM", model_dir) + if not os.path.exists(model_path): + logger.info(f"Downloading model '{model_name}' to: {model_path}") + try: + snapshot_download( + repo_id=model_name, + local_dir=model_path, + local_dir_use_symlinks=False, + token=os.getenv("HUGGINGFACE_TOKEN") or "" + ) + logger.info(f"Model '{model_name}' downloaded successfully.") + except Exception as e: + logger.error(f"An error occurred while downloading the model '{model_name}': {e}") + return None + else: + logger.info(f"Model '{model_name}' already exists at: {model_path}") + return model_path + + def hash_seed(self, seed): + import hashlib + seed_bytes = str(seed).encode('utf-8') + hash_object = hashlib.sha256(seed_bytes) + hashed_seed = int(hash_object.hexdigest(), 16) + return hashed_seed % (2**32) + + def load_model(self, model: str, precision: str, attention: str) -> Optional[Dict[str, Any]]: + if model in self.loaded_models: + logger.info(f"Model '{model}' already loaded and cached.") + return self.loaded_models[model] + + if precision == "int8": + quant_config = BitsAndBytesConfig(load_in_8bit=True) + dtype = torch.bfloat16 if 'mpt' in model.lower() or 'llama2' in model.lower() else torch.float16 + elif precision == "int4": + quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16 + ) + dtype = torch.bfloat16 + else: + quant_config = None + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}.get(precision, torch.float16) + + model_path = self.download_model_if_not_exists(model) + if model_path is None: + logger.error(f"Model path for '{model}' could not be determined.") + return None + + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + logger.error(f"Config file not found at: {config_path}") + return None + + device = self.device + try: + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + architectures = config.architectures + + if architectures and isinstance(architectures, list) and len(architectures) > 0: + model_class = architectures[0] + try: + common_args = { + "pretrained_model_name_or_path": model_path, + "attn_implementation": attention, + "torch_dtype": dtype, + "trust_remote_code": True, + "device_map": device, + } + + if quant_config: + common_args["quantization_config"] = quant_config + + if "florence" in model.lower() or 'florence' in model_path.lower() or "deepseek" in model.lower() or 'deepseek' in model_path.lower(): + with patch("transformers.dynamic_module_utils.get_imports", self.fixed_get_imports): + loaded_model = AutoModelForCausalLM.from_pretrained(**common_args) + elif "pixtral" in model.lower(): + from transformers import LlavaForConditionalGeneration + loaded_model = LlavaForConditionalGeneration.from_pretrained(**common_args, use_safetensors=True) + elif "molmo" in model.lower(): + loaded_model = AutoModelForCausalLM.from_pretrained(**common_args, use_safetensors=True) + elif "qwen2-vl" in model.lower(): + min_pixels = 224 * 224 + max_pixels = 1024 * 1024 + processor = Qwen2VLProcessor.from_pretrained( + model_path, + min_pixels=min_pixels, + max_pixels=max_pixels, + trust_remote_code=True + ) + loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(**common_args, use_safetensors=True) + else: + loaded_model = model_class.from_pretrained(**common_args) + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + except AttributeError: + logger.warning(f"AttributeError encountered. Forcing trust_remote_code=True for model: {model}") + loaded_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map=device) + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + except Exception as e: + logger.error(f"Error loading model from config.json: {e}") + return None + + self.loaded_models[model] = {'model': loaded_model, 'processor': processor, 'dtype': dtype} + logger.info(f"Model '{model}' loaded successfully and cached.") + return self.loaded_models[model] + + async def send_transformers_request( + self, + model_name, + system_message, + user_message, + messages, + max_new_tokens, + images, + temperature, + top_p, + top_k, + stop_strings_list, + repetition_penalty, + seed, + keep_alive=True, + precision="fp16", + attention="sdpa", + ): + try: + if model_name in self.loaded_models: + logger.info(f"Model '{model_name}' already loaded and cached.") + model_data = self.loaded_models[model_name] + else: + model_data = self.load_model(model_name, precision=precision, attention=attention) + if model_data is None: + raise ValueError(f"Failed to load model '{model_name}'.") + + model = model_data['model'] + processor = model_data['processor'] + tokenizer = processor.tokenizer + dtype = model_data['dtype'] + + if seed is not None: + logger.info(f"Setting seed: {seed}") + set_seed(self.hash_seed(seed)) + + # Convert to PIL Images if necessary + pil_images = [] + if isinstance(images, torch.Tensor): + images = images.permute(0, 3, 1, 2) + for img in images: + pil_images.append(TF.to_pil_image(img)) + elif isinstance(images, list) and all(isinstance(img, Image.Image) for img in images): + pil_images = images + else: + raise ValueError("Images must be either a torch.Tensor or a list of PIL Images") + + logger.debug(f"Number of images processed: {len(pil_images)}") + + # Construct standardized messages + formatted_messages = self.construct_messages(system_message, user_message, messages, pil_images) + + logger.debug(f"Formatted messages: {formatted_messages}") + + if 'florence' in model_name.lower(): + # Process input for Florence models + generated_texts = [] + responses = [] + images_pil = [] + for pil_image in pil_images: + inputs = processor(images=[pil_image], text=user_message, return_tensors="pt", do_rescale=False).to(dtype).to(model.device) + + logger.debug(f"Inputs shape: {inputs['pixel_values'].shape}, dtype: {inputs['pixel_values'].dtype}") + logger.debug(f"Input IDs shape: {inputs['input_ids'].shape}, dtype: {inputs['input_ids'].dtype}") + + with torch.random.fork_rng(devices=[model.device]): + torch.random.manual_seed(seed) + + try: + generated_ids = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + num_beams=3, + do_sample=True, + temperature=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + ) + except Exception as e: + logger.error(f"Error during model.generate: {e}") + raise + + results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + generated_text = self.clean_results(results, user_message) + response = processor.post_process_generation(generated_text, task=user_message, image_size=pil_image.size) + generated_texts.append(generated_text) + responses.append(response) + # images_pil.append(pil_image) + + result = (generated_texts, responses) + else: + # Handle other transformers models + inputs = processor(formatted_messages, return_tensors="pt", padding=True).to(model.device) + + # Convert inputs to the correct dtype + inputs = {k: v.to(dtype=torch.long if v.dtype == torch.int64 else dtype) if torch.is_tensor(v) else v for k, v in inputs.items()} + + with torch.no_grad(): + try: + outputs = model.generate( + **inputs, + generation_config=GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + ), + stopping_criteria=[StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings_list)], + ) + except Exception as e: + logger.error(f"Error during model.generate: {e}") + raise + + generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0] + + result = generated_text + + if not keep_alive: + self.unload_model(model_name) + + return result + + except Exception as e: + logger.error(f"Error in Transformers API request: {e}", exc_info=True) + return str(e) + + def clean_results(self, results, task): + if task == 'ocr_with_region': + clean_results = re.sub(r'|<[^>]*>', '\n', results) + clean_results = re.sub(r'\n+', '\n', clean_results) + else: + clean_results = results.replace('', '').replace('', '') + return clean_results + + def construct_messages(self, system_message, user_message, messages, pil_images): + """Constructs a standardized message format for transformer models.""" + formatted_messages = [] + if system_message: + formatted_messages.append({"role": "system", "content": system_message}) + + for msg in messages: + formatted_messages.append({"role": msg['role'], "content": msg['content']}) + + if user_message: + formatted_messages.append({ + "role": "user", + "content": [ + {"type": "text", "text": user_message}, + *[{"type": "image", "image": img} for img in pil_images] + ] + }) + + return formatted_messages + + def unload_model(self, model_name: str): + print(f"Offloading model: {model_name}") + if model_name in self.loaded_models: + model = self.loaded_models[model_name]['model'] + model.to(self.offload_device) + del self.loaded_models[model_name] + mm.soft_empty_cache() + else: + print(f"Model {model_name} not found in loaded models.") + + @classmethod + def fixed_get_imports(cls, filename: Union[str, os.PathLike], *args, **kwargs) -> List[str]: + """Remove 'flash_attn' from imports if present.""" + try: + if not str(filename).endswith("modeling_florence2.py") or not str(filename).endswith("modeling_deepseek.py"): + return get_imports(filename) + imports = get_imports(filename) + if "flash_attn" in imports: + imports.remove("flash_attn") + return imports + except Exception as e: + print(f"No flash_attn import to remove: {e}") + return get_imports(filename) + + +# Initialize a global manager instance +_transformers_manager = TransformersModelManager() diff --git a/utils.py b/utils.py index 1a98e38..ce0e5c4 100644 --- a/utils.py +++ b/utils.py @@ -396,6 +396,52 @@ def process_mask(retrieved_mask, image_tensor): # Return a default mask matching the image dimensions return torch.ones((image_tensor.shape[0], image_tensor.shape[2], image_tensor.shape[3]), dtype=torch.float32) +def convert_mask_to_grayscale_alpha(mask_input): + """ + Convert mask to grayscale alpha channel. + Handles tensors, PIL images and numpy arrays. + Returns tensor in shape [B,1,H,W]. + """ + if isinstance(mask_input, torch.Tensor): + # Handle tensor input + if mask_input.dim() == 2: # [H,W] + return mask_input.unsqueeze(0).unsqueeze(0) # Add batch and channel dims + elif mask_input.dim() == 3: # [C,H,W] or [B,H,W] + if mask_input.shape[0] in [1,3,4]: # Assume channel-first + if mask_input.shape[0] == 4: # Use alpha channel + return mask_input[3:4].unsqueeze(0) + else: # Convert to grayscale + weights = torch.tensor([0.299, 0.587, 0.114]).to(mask_input.device) + return (mask_input * weights.view(-1,1,1)).sum(0).unsqueeze(0).unsqueeze(0) + else: # Assume batch dimension + return mask_input.unsqueeze(1) # Add channel dim + elif mask_input.dim() == 4: # [B,C,H,W] + if mask_input.shape[1] == 4: # Use alpha channel + return mask_input[:,3:4] + else: # Convert to grayscale + weights = torch.tensor([0.299, 0.587, 0.114]).to(mask_input.device) + return (mask_input * weights.view(1,-1,1,1)).sum(1).unsqueeze(1) + + elif isinstance(mask_input, Image.Image): + # Convert PIL image to grayscale + mask = mask_input.convert('L') + tensor = torch.from_numpy(np.array(mask)).float() / 255.0 + return tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dims + + elif isinstance(mask_input, np.ndarray): + # Handle numpy array + if mask_input.ndim == 2: # [H,W] + tensor = torch.from_numpy(mask_input).float() + return tensor.unsqueeze(0).unsqueeze(0) + elif mask_input.ndim == 3: # [H,W,C] + if mask_input.shape[2] == 4: # Use alpha channel + tensor = torch.from_numpy(mask_input[:,:,3]).float() + else: # Convert to grayscale + tensor = torch.from_numpy(np.dot(mask_input[...,:3], [0.299, 0.587, 0.114])).float() + return tensor.unsqueeze(0).unsqueeze(0) + + raise ValueError(f"Unsupported mask input type: {type(mask_input)}") + def tensor_to_base64(tensor: torch.Tensor) -> str: """Convert a tensor to a base64-encoded PNG image string.""" try: