diff --git a/src/config/configuration.py b/src/config/configuration.py index e7845d58b..defdb4acf 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -42,6 +42,7 @@ def get_recursion_limit(default: int = 25) -> int: class Configuration: """The configurable fields.""" + thread_id: str = field(default="") resources: list[Resource] = field( default_factory=list ) # Resources to be used for the research diff --git a/src/graph/checkpoint.py b/src/graph/checkpoint.py index 9cec30aba..312afa3ac 100644 --- a/src/graph/checkpoint.py +++ b/src/graph/checkpoint.py @@ -5,7 +5,8 @@ import logging import uuid from datetime import datetime -from typing import List, Optional, Tuple + +from typing import List, Optional, Tuple, cast import psycopg from langgraph.store.memory import InMemoryStore @@ -86,6 +87,8 @@ def _init_postgresql(self) -> None: self.postgres_conn = psycopg.connect(self.db_uri, row_factory=dict_row) self.logger.info("Successfully connected to PostgreSQL") self._create_chat_streams_table() + self._create_langgraph_events_table() + self._create_research_replays_table() except Exception as e: self.logger.error(f"Failed to connect to PostgreSQL: {e}") @@ -112,6 +115,319 @@ def _create_chat_streams_table(self) -> None: if self.postgres_conn: self.postgres_conn.rollback() + def _create_langgraph_events_table(self) -> None: + """Create the langgraph_events table if it doesn't exist.""" + try: + with self.postgres_conn.cursor() as cursor: + create_table_sql = """ + CREATE TABLE IF NOT EXISTS langgraph_events ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + thread_id VARCHAR(255) NOT NULL, + event VARCHAR(255) NOT NULL, + level VARCHAR(50) NOT NULL, + message JSONB NOT NULL, + ts TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + CREATE INDEX IF NOT EXISTS idx_langgraph_events_thread_id ON langgraph_events(thread_id); + CREATE INDEX IF NOT EXISTS idx_langgraph_events_ts ON langgraph_events(ts); + """ + cursor.execute(create_table_sql) + self.postgres_conn.commit() + self.logger.info("Langgraph events table created/verified successfully") + except Exception as e: + self.logger.error(f"Failed to create langgraph_events table: {e}") + if self.postgres_conn: + self.postgres_conn.rollback() + + def _create_research_replays_table(self) -> None: + """Create the research_replays table if it doesn't exist.""" + try: + with self.postgres_conn.cursor() as cursor: + create_table_sql = """ + CREATE TABLE IF NOT EXISTS research_replays ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + thread_id VARCHAR(255) NOT NULL, + research_topic VARCHAR(255) NOT NULL, + report_style VARCHAR(50) NOT NULL, + messages INTEGER NOT NULL, + ts TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + CREATE INDEX IF NOT EXISTS idx_research_replays_thread_id ON research_replays(thread_id); + CREATE INDEX IF NOT EXISTS idx_research_replays_ts ON research_replays(ts); + """ + cursor.execute(create_table_sql) + self.postgres_conn.commit() + self.logger.info("Research replays table created/verified successfully") + except Exception as e: + self.logger.error(f"Failed to create research_replays table: {e}") + if self.postgres_conn: + self.postgres_conn.rollback() + + def _process_stream_messages(self, stream_message: dict | str | None) -> str: + if stream_message is None: + return "" + if isinstance(stream_message, str): + # If stream_message is a string, return it directly + return stream_message + if not isinstance(stream_message, dict): + # If stream_message is not a dict, return an empty string + 
return "" + messages = cast(list, stream_message.get("messages", [])) + # remove the first message which is usually the system prompt + if messages and isinstance(messages, list) and len(messages) > 0: + # Decode byte messages back to strings + decoded_messages = [] + for message in messages: + if isinstance(message, bytes): + decoded_messages.append(message.decode("utf-8")) + else: + decoded_messages.append(str(message)) + # Return all messages except the first one + valid_messages = [] + for message in decoded_messages: + if ( + str(message).find("event:") == -1 + and str(message).find("data:") == -1 + ): + continue + if str(message).find("message_chunk") > -1: + if ( + str(message).find("content") > -1 + or str(message).find("reasoning_content") > -1 + or str(message).find("finish_reason") > -1 + ): + valid_messages.append(message) + + else: + valid_messages.append(message) + return "".join(valid_messages) if valid_messages else "" + elif messages and isinstance(messages, str): + # If messages is a single string, return it directly + return messages + else: + # If no messages found, return an empty string + return "" + + def log_research_replays( + self, thread_id: str, research_topic: str, report_style: str, messages: int + ) -> None: + if not self.checkpoint_saver: + logging.warning( + "Checkpoint saver is disabled, cannot retrieve conversation" + ) + return None + if self.mongo_db is None and self.postgres_conn is None: + logging.warning("No DB connection available") + return None + if self.mongo_db is not None: + try: + collection = self.mongo_db.research_replays + # Update existing conversation with new messages count + if messages > 0: + existing_document = collection.find_one({"thread_id": thread_id}) + if existing_document: + update_result = collection.update_one( + {"thread_id": thread_id}, + { + "$set": { + "messages": messages, + } + }, + ) + self.logger.info( + f"Updated research replay for thread {thread_id}: " + f"{update_result.modified_count} documents modified" + ) + else: + result = collection.insert_one( + { + "thread_id": thread_id, + "research_topic": research_topic, + "report_style": report_style, + "messages": messages, + "ts": datetime.now(), + "id": uuid.uuid4().hex, + } + ) + self.logger.info(f"Event logged: {result.inserted_id}") + except Exception as e: + self.logger.error(f"Error logging event: {e}") + elif self.postgres_conn is not None: + try: + # Update existing conversation with new messages count + if messages > 0: + with self.postgres_conn.cursor() as cursor: + cursor.execute( + "SELECT id FROM research_replays WHERE thread_id = %s", + (thread_id,), + ) + existing_record = cursor.fetchone() + if existing_record: + with self.postgres_conn.cursor() as cursor: + cursor.execute( + """ + UPDATE research_replays + SET messages = %s + WHERE thread_id = %s + """, + (messages, thread_id), + ) + self.postgres_conn.commit() + self.logger.info( + f"Updated research replay for thread {thread_id}: " + f"{cursor.rowcount} rows modified" + ) + else: + with self.postgres_conn.cursor() as cursor: + cursor.execute( + """ + INSERT INTO research_replays (thread_id, research_topic, report_style, messages, ts) + VALUES (%s, %s, %s, %s, %s) + """, + ( + thread_id, + research_topic, + report_style, + messages, + datetime.now(), + ), + ) + self.postgres_conn.commit() + self.logger.info("Research replay logged successfully") + except Exception as e: + self.logger.error(f"Error logging research replay: {e}") + + def log_graph_event( + self, thread_id: str, event: str, level: str, 
message: dict + ) -> None: + """ + Log an event related to a conversation thread. + Args: + thread_id (str): Unique identifier for the conversation thread + event (str): Event type or name + level (str): Log level (e.g., "info", "warning", "error") + message (dict): Additional message data to log + """ + if not self.checkpoint_saver: + logging.warning( + "Checkpoint saver is disabled, cannot retrieve conversation" + ) + return None + if self.mongo_db is None and self.postgres_conn is None: + logging.warning("No mongodb connection available") + return None + if self.mongo_db is not None: + try: + collection = self.mongo_db.langgraph_events + result = collection.insert_one( + { + "thread_id": thread_id, + "event": event, + "level": level, + "message": message, + "ts": datetime.now(), + "id": uuid.uuid4().hex, + } + ) + self.logger.info(f"Event logged: {result.inserted_id}") + except Exception as e: + self.logger.error(f"Error logging event: {e}") + elif self.postgres_conn is not None: + try: + with self.postgres_conn.cursor() as cursor: + cursor.execute( + """ + INSERT INTO langgraph_events (thread_id, event, level, message, ts) + VALUES (%s, %s, %s, %s, %s) + """, + ( + thread_id, + event, + level, + json.dumps(message), + datetime.now(), + ), + ) + self.postgres_conn.commit() + self.logger.info("Event logged successfully") + except Exception as e: + self.logger.error(f"Error logging event: {e}") + + def get_messages_by_id(self, thread_id: str) -> Optional[str]: + """Retrieve a conversation by thread_id.""" + if not self.checkpoint_saver: + logging.warning( + "Checkpoint saver is disabled, cannot retrieve conversation" + ) + return None + if self.mongo_db is None and self.postgres_conn is None: + logging.warning("No database connection available") + return None + if self.mongo_db is not None: + # MongoDB retrieval + collection = self.mongo_db.chat_streams + conversation = collection.find_one({"thread_id": thread_id}) + if conversation is None: + logging.warning(f"No conversation found for thread_id: {thread_id}") + return None + messages = self._process_stream_messages(conversation) + return messages + elif self.postgres_conn: + # PostgreSQL retrieval + with self.postgres_conn.cursor() as cursor: + cursor.execute( + "SELECT * FROM chat_streams WHERE thread_id = %s", (thread_id,) + ) + conversation = cursor.fetchone() + if conversation is None: + logging.warning(f"No conversation found for thread_id: {thread_id}") + return None + messages = self._process_stream_messages(conversation) + return messages + else: + logging.warning("No database connection available") + return None + + def get_stream_messages(self, limit: int = 10, sort: str = "ts") -> List[dict]: + """ + Retrieve chat stream messages from the database. 
+
+        Args:
+            limit (int): Maximum number of messages to retrieve
+            sort (str): Field to sort by, default is 'ts' (timestamp)
+        Returns:
+            List[dict]: List of chat stream messages, sorted by the specified field
+        """
+        if not self.checkpoint_saver:
+            self.logger.warning(
+                "Checkpoint saver is disabled, cannot retrieve messages"
+            )
+            return []
+        if self.mongo_db is None and self.postgres_conn is None:
+            self.logger.warning("No database connection available")
+            return []
+        # ORDER BY identifiers cannot be bound as query parameters, so
+        # whitelist the sort column instead of interpolating raw input.
+        if sort not in ("ts", "thread_id"):
+            self.logger.warning(f"Unsupported sort field '{sort}', falling back to 'ts'")
+            sort = "ts"
+        try:
+            if self.mongo_db is not None:
+                # MongoDB retrieval
+                collection = self.mongo_db.research_replays
+                cursor = collection.find().sort(sort, -1).limit(limit)
+                messages = list(cursor) if cursor is not None else []
+                return messages
+            elif self.postgres_conn:
+                # PostgreSQL retrieval
+                with self.postgres_conn.cursor() as cursor:
+                    query = (
+                        f"SELECT * FROM research_replays ORDER BY {sort} DESC LIMIT %s"
+                    )
+                    cursor.execute(query, (limit,))
+                    messages = cursor.fetchall()
+                    return messages
+            else:
+                self.logger.warning("No database connection available")
+                return []
+        except Exception as e:
+            self.logger.error(f"Error retrieving chat stream messages: {e}")
+            return []
+
     def process_stream_message(
         self, thread_id: str, message: str, finish_reason: str
     ) -> bool:
@@ -208,7 +524,8 @@ def _persist_complete_conversation(
         if not self.checkpoint_saver:
             self.logger.warning("Checkpoint saver is disabled")
             return False
-
+        # Record the replay's message count before persisting the conversation
+        self.log_research_replays(thread_id, "", "", len(messages))
         # Choose persistence method based on available connection
         if self.mongo_db is not None:
             return self._persist_to_mongodb(thread_id, messages)
@@ -371,3 +688,44 @@ def chat_stream_message(thread_id: str, message: str, finish_reason: str) -> bool:
         )
     else:
         return False
+
+
+def list_conversations(limit: int, sort: str = "ts"):
+    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
+    if checkpoint_saver:
+        return _default_manager.get_stream_messages(limit, sort)
+    else:
+        logging.warning("Checkpoint saver is disabled, conversations not listed")
+        return []
+
+
+def get_conversation(thread_id: str):
+    """Retrieve a conversation by thread_id."""
+    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
+    if checkpoint_saver:
+        return _default_manager.get_messages_by_id(thread_id)
+    else:
+        logging.warning("Checkpoint saver is disabled, conversation not retrieved")
+        return ""
+
+
+def log_graph_event(thread_id: str, event: str, level: str, message: dict):
+    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
+    if checkpoint_saver:
+        return _default_manager.log_graph_event(thread_id, event, level, message)
+    else:
+        logging.warning("Checkpoint saver is disabled, event not logged")
+        return ""
+
+
+def log_research_replays(
+    thread_id: str, research_topic: str, report_style: str, messages: int
+):
+    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
+    if checkpoint_saver:
+        return _default_manager.log_research_replays(
+            thread_id, research_topic, report_style, messages
+        )
+    else:
+        logging.warning("Checkpoint saver is disabled, replay not logged")
+        return ""
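For orientation, here is a minimal sketch of how the new module-level helpers could be exercised from application code. It assumes `LANGGRAPH_CHECKPOINT_SAVER` is enabled and a MongoDB or PostgreSQL backend is reachable; the thread id is a placeholder.

```python
from src.graph.checkpoint import get_conversation, list_conversations

# List the most recent research replays, newest first.
for replay in list_conversations(limit=10, sort="ts"):
    print(replay.get("thread_id"), replay.get("research_topic"), replay.get("messages"))

# Fetch the filtered SSE text for one thread; the helper returns "" when the
# saver is disabled and None when the thread id is unknown.
events = get_conversation("demo-thread-id")
if events:
    print(events[:200])
```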
diff --git a/src/graph/nodes.py b/src/graph/nodes.py
index 12dc9bfd8..269b0deb2 100644
--- a/src/graph/nodes.py
+++ b/src/graph/nodes.py
@@ -4,7 +4,7 @@
 import json
 import logging
 import os
-from typing import Annotated, Literal
+from typing import Annotated, Literal, cast
 
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.runnables import RunnableConfig
@@ -28,6 +28,7 @@ from src.utils.json_utils import repair_json_output
 
 from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
+from .checkpoint import log_research_replays, log_graph_event
 from .types import State
 
 logger = logging.getLogger(__name__)
@@ -58,13 +59,18 @@ def background_investigation_node(state: State, config: RunnableConfig):
         searched_content = searched_content[0]
         if isinstance(searched_content, list):
             background_investigation_results = [
-                f"## {elem['title']}\n\n{elem['content']}" for elem in searched_content
+                f"## {elem['title']}\n\n{elem['content']}"
+                for elem in searched_content
+                # if elem.get("type") == "page"
             ]
-            return {
-                "background_investigation_results": "\n\n".join(
-                    background_investigation_results
-                )
-            }
+            results = "\n\n".join(background_investigation_results)
+            # Build checkpoint with the background investigation results
+            log_graph_event(
+                configurable.thread_id,
+                "background_investigator",
+                "info",
+                {"goto": "planner", "investigations": results},
+            )
+            return {"background_investigation_results": results}
         else:
             logger.error(
                 f"Tavily search returned malformed response: {searched_content}"
@@ -73,11 +79,15 @@ def background_investigation_node(state: State, config: RunnableConfig):
         background_investigation_results = get_web_search_tool(
             configurable.max_search_results
         ).invoke(query)
-        return {
-            "background_investigation_results": json.dumps(
-                background_investigation_results, ensure_ascii=False
-            )
-        }
+        results = json.dumps(background_investigation_results, ensure_ascii=False)
+        # Build checkpoint with the background investigation results
+        log_graph_event(
+            configurable.thread_id,
+            "background_investigator",
+            "info",
+            {"goto": "planner", "investigations": results},
+        )
+        return {"background_investigation_results": results}
 
 
 def planner_node(
@@ -139,6 +149,13 @@ def planner_node(
     if isinstance(curr_plan, dict) and curr_plan.get("has_enough_context"):
         logger.info("Planner response has enough context.")
         new_plan = Plan.model_validate(curr_plan)
+        # Build checkpoint with the current plan
+        log_graph_event(
+            configurable.thread_id,
+            "planner",
+            "info",
+            {"goto": "reporter", "current_plan": curr_plan},
+        )
         return Command(
             update={
                 "messages": [AIMessage(content=full_response, name="planner")],
@@ -146,6 +163,13 @@ def planner_node(
             },
             goto="reporter",
         )
+    # Build checkpoint with the current plan
+    log_graph_event(
+        configurable.thread_id,
+        "planner",
+        "info",
+        {"goto": "human_feedback", "current_plan": curr_plan},
+    )
     return Command(
         update={
             "messages": [AIMessage(content=full_response, name="planner")],
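Since every node now emits structured rows into `langgraph_events`, a run can be reconstructed thread by thread. A small diagnostic sketch against the PostgreSQL backend follows; the DSN and thread id are placeholders, and `message` is the JSONB column from the schema above, which psycopg returns as a dict:

```python
import psycopg
from psycopg.rows import dict_row

with psycopg.connect("postgresql://localhost/deerflow", row_factory=dict_row) as conn:
    with conn.cursor() as cur:
        cur.execute(
            "SELECT event, level, message, ts FROM langgraph_events "
            "WHERE thread_id = %s ORDER BY ts",
            ("demo-thread-id",),
        )
        for row in cur.fetchall():
            # Each node logs at least a "goto" target, so this prints the
            # path the graph took through planner, human_feedback, etc.
            print(row["ts"], row["event"], row["message"].get("goto"))
```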
@@ -156,8 +180,9 @@
 def human_feedback_node(
-    state,
+    state, config: RunnableConfig
 ) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]:
+    configurable = Configuration.from_runnable_config(config)
     current_plan = state.get("current_plan", "")
     # check if the plan is auto accepted
     auto_accepted_plan = state.get("auto_accepted_plan", False)
@@ -194,7 +219,13 @@ def human_feedback_node(
             return Command(goto="reporter")
         else:
             return Command(goto="__end__")
-
+    # Build checkpoint with the accepted plan
+    log_graph_event(
+        configurable.thread_id,
+        "human_feedback",
+        "info",
+        {"goto": goto, "current_plan": new_plan, "plan_iterations": plan_iterations},
+    )
     return Command(
         update={
             "current_plan": Plan.model_validate(new_plan),
@@ -248,6 +279,18 @@ def coordinator_node(
     messages = state.get("messages", [])
     if response.content:
         messages.append(HumanMessage(content=response.content, name="coordinator"))
+
+    # Record the research replay and the coordinator hand-off
+    log_research_replays(
+        configurable.thread_id, research_topic, configurable.report_style, 0
+    )
+    log_graph_event(
+        configurable.thread_id,
+        "coordinator",
+        "info",
+        {"goto": goto, "research_topic": research_topic},
+    )
+
     return Command(
         update={
             "messages": messages,
@@ -294,7 +337,13 @@ def reporter_node(state: State, config: RunnableConfig):
     response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
     response_content = response.content
     logger.info(f"reporter response: {response_content}")
-
+    # Build checkpoint with the final report
+    log_graph_event(
+        configurable.thread_id,
+        "reporter",
+        "info",
+        {"goto": "end", "final_report": response_content},
+    )
     return {"final_report": response_content}
 
 
@@ -305,13 +354,13 @@ def research_team_node(state: State):
 
 
 async def _execute_agent_step(
-    state: State, agent, agent_name: str
+    state: State, config: RunnableConfig, agent, agent_name: str
 ) -> Command[Literal["research_team"]]:
     """Helper function to execute a step using the specified agent."""
     current_plan = state.get("current_plan")
     plan_title = current_plan.title
     observations = state.get("observations", [])
-
+    configurable = Configuration.from_runnable_config(config)
     # Find the first unexecuted step
     current_step = None
     completed_steps = []
@@ -402,6 +451,28 @@ async def _execute_agent_step(
     # Update the step with the execution result
     current_step.execution_res = response_content
     logger.info(f"Step '{current_step.title}' execution completed by {agent_name}")
+    # Build checkpoint with the executed step. agent_input["messages"] holds
+    # BaseMessage objects, so read their fields via attributes rather than
+    # treating them as tuples (which would never match and log nothing).
+    agent_input_messages = []
+    for message in agent_input["messages"]:
+        if hasattr(message, "content"):
+            agent_input_messages.append(
+                {
+                    "role": getattr(message, "type", "user"),
+                    "content": message.content,
+                    "name": getattr(message, "name", None),
+                }
+            )
+    log_graph_event(
+        configurable.thread_id,
+        "agent",
+        "info",
+        {
+            "goto": "research_team",
+            "agent": agent_name,
+            "input": agent_input_messages,
+            "observations": observations + [response_content],
+        },
+    )
 
     return Command(
         update={
@@ -470,11 +541,11 @@ async def _setup_and_execute_agent_step(
             )
             loaded_tools.append(tool)
         agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
-        return await _execute_agent_step(state, agent, agent_type)
+        return await _execute_agent_step(state, config, agent, agent_type)
     else:
         # Use default tools if no MCP servers are configured
         agent = create_agent(agent_type, agent_type, default_tools, agent_type)
-        return await _execute_agent_step(state, agent, agent_type)
+        return await _execute_agent_step(state, config, agent, agent_type)
 
 
 async def researcher_node(
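The nodes rely on the new `Configuration.thread_id` field added at the top of this patch. A tiny sketch of that plumbing, assuming `from_runnable_config()` lifts matching keys out of the RunnableConfig `configurable` mapping (which is how the nodes above consume it):

```python
from src.config.configuration import Configuration

# The caller (e.g. the chat API) passes the thread id through LangGraph's config.
config = {"configurable": {"thread_id": "demo-thread-id", "report_style": "academic"}}
configurable = Configuration.from_runnable_config(config)

# thread_id defaults to "", so the logging helpers always receive a string
# even when a caller omits it.
assert configurable.thread_id == "demo-thread-id"
```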
diff --git a/src/server/app.py b/src/server/app.py
index d68412dc9..901074614 100644
--- a/src/server/app.py
+++ b/src/server/app.py
@@ -47,7 +47,19 @@
     RAGResourceRequest,
     RAGResourcesResponse,
 )
+from src.server.conversation_request import (
+    Conversation,
+    ConversationsRequest,
+    ConversationsResponse,
+)
 from src.tools import VolcengineTTS
+
+from src.graph.checkpoint import (
+    chat_stream_message,
+    get_conversation,
+    list_conversations,
+)
+
 from src.utils.json_utils import sanitize_args
 
 logger = logging.getLogger(__name__)
@@ -611,3 +623,48 @@ async def config():
         rag=RAGConfigResponse(provider=SELECTED_RAG_PROVIDER),
         models=get_configured_llm_models(),
     )
+
+
+@app.get("/api/conversation/{thread_id}", response_model=str)
+async def get_conversation_by_id(thread_id: str) -> Response:
+    """Get the conversation content for a specific thread ID."""
+    try:
+        content = get_conversation(thread_id)
+        if not content:
+            raise HTTPException(status_code=404, detail="Conversation not found")
+
+        return Response(
+            content=content,
+            media_type="text/plain",
+            headers={"Content-Type": "text/plain; charset=utf-8"},
+        )
+
+    except Exception as e:
+        logger.exception(f"Error getting conversation: {str(e)}")
+        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
+
+
+@app.get("/api/conversations", response_model=ConversationsResponse)
+async def get_conversations(
+    request: Annotated[ConversationsRequest, Query()],
+) -> ConversationsResponse:
+    """Get conversations based on the provided request parameters."""
+    try:
+        conversations = list_conversations(limit=request.limit, sort=request.sort)
+        response = []
+        for conversation in conversations:
+            response.append(
+                Conversation(
+                    id=conversation.get("thread_id", ""),
+                    title=conversation.get("research_topic", ""),
+                    count=conversation.get("messages", 0),
+                    date=conversation.get("ts", ""),
+                    category=conversation.get("report_style", ""),
+                    data_type="database",  # Listed conversations come from the database
+                )
+            )
+        data = ConversationsResponse(data=response)
+        return data
+    except Exception as e:
+        logger.exception(f"Error getting conversations: {str(e)}")
+        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
diff --git a/src/server/conversation_request.py b/src/server/conversation_request.py
new file mode 100644
index 000000000..3af6f1d51
--- /dev/null
+++ b/src/server/conversation_request.py
@@ -0,0 +1,53 @@
+from datetime import datetime
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class Conversation(BaseModel):
+    id: Optional[str] = Field("", description="The thread ID of the conversation.")
+    title: Optional[str] = Field("", description="The title of the conversation")
+    date: Optional[datetime] = Field(
+        None, description="The date of the conversation, formatted as 'YYYY-MM-DD'."
+    )
+
+    category: Optional[str] = Field(
+        "Social Media", description="The writing style of the conversation."
+    )
+    count: Optional[int] = Field(
+        0, description="The number of messages in the conversation."
+    )
+    data_type: Optional[str] = Field(
+        "txt", description="The type of data in the conversation, e.g., 'txt', 'json'."
+    )
+
+
+class ConversationsResponse(BaseModel):
+
+    data: Optional[list[Conversation]] = Field(
+        default_factory=list,
+        description="List of replays matching the request criteria",
+    )
+
+
+class ConversationsRequest(BaseModel):
+    """Request model for conversation listing queries.
+
+    This model represents a request to list stored conversations (research
+    replays). It encapsulates the pagination and sorting parameters.
+
+    Attributes:
+        limit: The maximum number of conversations to retrieve.
+        offset: The pagination offset, used to skip a number of conversations.
+        sort: The field by which to sort the results.
+    """
+
+    limit: Optional[int] = Field(
+        None, description="The maximum number of conversations to retrieve"
+    )
+    offset: Optional[int] = Field(
+        None,
+        description="The offset for pagination, used to skip a number of conversations",
+    )
+    sort: Optional[str] = Field(
+        None, description="The field by which to sort the conversations"
+    )
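A quick way to exercise both new endpoints from a script; the base URL assumes a local dev server and the thread id is a placeholder:

```python
import requests

BASE = "http://localhost:8000/api"  # assumed local dev server

# Mirrors ConversationsRequest: limit, offset, and sort arrive as query params.
replays = requests.get(f"{BASE}/conversations", params={"limit": 10, "sort": "ts"}).json()
for item in replays.get("data", []):
    print(item["id"], item["title"], item["count"])

# Returns the raw SSE text for a stored thread, or 404 when it is unknown.
resp = requests.get(f"{BASE}/conversation/demo-thread-id")
if resp.ok:
    print(resp.text[:200])
```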
+ """ + + limit: Optional[int] = Field( + None, description="The maximum number of resources to retrieve" + ) + offset: Optional[int] = Field( + None, + description="The offset for pagination, used to skip a number of resources", + ) + sort: Optional[str] = Field( + None, description="The field by which to sort the resources" + ) diff --git a/src/tools/tavily_search/tavily_search_api_wrapper.py b/src/tools/tavily_search/tavily_search_api_wrapper.py index f1945a5bd..536efe061 100644 --- a/src/tools/tavily_search/tavily_search_api_wrapper.py +++ b/src/tools/tavily_search/tavily_search_api_wrapper.py @@ -104,10 +104,19 @@ def clean_results_with_images( clean_results.append(clean_result) images = raw_results["images"] for image in images: - clean_result = { - "type": "image", - "image_url": image["url"], - "image_description": image["description"], - } + if isinstance(image, str): + clean_result = { + "type": "image", + "image_url": image, + "image_description": "", + } + elif isinstance(image, dict): + clean_result = { + "type": "image", + "image_url": image.get("url"), + "image_description": image.get("description", ""), + } + else: + continue clean_results.append(clean_result) return clean_results diff --git a/src/tools/tavily_search/tavily_search_results_with_images.py b/src/tools/tavily_search/tavily_search_results_with_images.py index 7ecde9eef..28f402baf 100644 --- a/src/tools/tavily_search/tavily_search_results_with_images.py +++ b/src/tools/tavily_search/tavily_search_results_with_images.py @@ -3,7 +3,7 @@ import json import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 161e55688..0ebfbf986 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -52,6 +52,14 @@ def mock_config(): return MagicMock() +@pytest.fixture +def mock_config_thread(): + # 你可以根据实际需要返回一个 MagicMock 或 dict + mock = MagicMock() + mock.thread_id = "_default_" + return mock + + @pytest.fixture def patch_config_from_runnable_config(mock_configurable): with patch( @@ -416,52 +424,58 @@ def mock_state_base(): } -def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base): +def test_human_feedback_node_auto_accepted( + monkeypatch, mock_state_base, mock_config_thread +): # auto_accepted_plan True, should skip interrupt and parse plan state = dict(mock_state_base) state["auto_accepted_plan"] = True - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config_thread) assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["plan_iterations"] == 1 assert result.update["current_plan"]["has_enough_context"] is False -def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base): +def test_human_feedback_node_edit_plan( + monkeypatch, mock_state_base, mock_config_thread +): # interrupt returns [EDIT_PLAN]..., should return Command to planner state = dict(mock_state_base) state["auto_accepted_plan"] = False with patch("src.graph.nodes.interrupt", return_value="[EDIT_PLAN] Please revise"): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config_thread) assert isinstance(result, Command) assert result.goto == "planner" assert result.update["messages"][0].name == "feedback" assert "[EDIT_PLAN]" in result.update["messages"][0].content -def 
test_human_feedback_node_accepted(monkeypatch, mock_state_base): +def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config_thread): # interrupt returns [ACCEPTED]..., should proceed to parse plan state = dict(mock_state_base) state["auto_accepted_plan"] = False with patch("src.graph.nodes.interrupt", return_value="[ACCEPTED] Looks good!"): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config_thread) assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["plan_iterations"] == 1 assert result.update["current_plan"]["has_enough_context"] is False -def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base): +def test_human_feedback_node_invalid_interrupt( + monkeypatch, mock_state_base, mock_config_thread +): # interrupt returns something else, should raise TypeError state = dict(mock_state_base) state["auto_accepted_plan"] = False with patch("src.graph.nodes.interrupt", return_value="RANDOM_FEEDBACK"): with pytest.raises(TypeError): - human_feedback_node(state) + human_feedback_node(state, mock_config_thread) def test_human_feedback_node_json_decode_error_first_iteration( - monkeypatch, mock_state_base + monkeypatch, mock_state_base, mock_config_thread ): # repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations=0 state = dict(mock_state_base) @@ -470,13 +484,13 @@ def test_human_feedback_node_json_decode_error_first_iteration( with patch( "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) ): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config_thread) assert isinstance(result, Command) assert result.goto == "__end__" def test_human_feedback_node_json_decode_error_second_iteration( - monkeypatch, mock_state_base + monkeypatch, mock_state_base, mock_config_thread ): # repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations>0 state = dict(mock_state_base) @@ -485,12 +499,14 @@ def test_human_feedback_node_json_decode_error_second_iteration( with patch( "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) ): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config_thread) assert isinstance(result, Command) assert result.goto == "reporter" -def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base): +def test_human_feedback_node_not_enough_context( + monkeypatch, mock_state_base, mock_config_thread +): # Plan does not have enough context, should goto research_team plan = { "has_enough_context": False, @@ -502,7 +518,7 @@ def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base): state = dict(mock_state_base) state["current_plan"] = json.dumps(plan) state["auto_accepted_plan"] = True - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config_thread) assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["plan_iterations"] == 1 @@ -905,14 +921,16 @@ async def ainvoke(input, config): @pytest.mark.asyncio -async def test_execute_agent_step_basic(mock_state_with_steps, mock_agent): +async def test_execute_agent_step_basic( + mock_state_with_steps, mock_agent, mock_config_thread +): # Should execute the first unexecuted step and update execution_res with patch( "src.graph.nodes.HumanMessage", side_effect=lambda content, name=None: MagicMock(content=content, name=name), ): result = await 
_execute_agent_step( - mock_state_with_steps, mock_agent, "researcher" + mock_state_with_steps, mock_config_thread, mock_agent, "researcher" ) assert isinstance(result, Command) assert result.goto == "research_team" @@ -929,12 +947,12 @@ async def test_execute_agent_step_basic(mock_state_with_steps, mock_agent): @pytest.mark.asyncio async def test_execute_agent_step_no_unexecuted_step( - mock_state_no_unexecuted, mock_agent + mock_state_no_unexecuted, mock_agent, mock_config_thread ): # Should return Command with goto="research_team" and not fail with patch("src.graph.nodes.logger") as mock_logger: result = await _execute_agent_step( - mock_state_no_unexecuted, mock_agent, "researcher" + mock_state_no_unexecuted, mock_config_thread, mock_agent, "researcher" ) assert isinstance(result, Command) assert result.goto == "research_team" @@ -964,11 +982,13 @@ async def ainvoke(input, config): return {"messages": [MagicMock(content="resource result")]} agent.ainvoke = ainvoke + config = MagicMock() + config.thread_id = "test_thread" with patch( "src.graph.nodes.HumanMessage", side_effect=lambda content, name=None: MagicMock(content=content, name=name), ): - result = await _execute_agent_step(state, agent, "researcher") + result = await _execute_agent_step(state, config, agent, "researcher") assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["observations"][-1] == "resource result" @@ -976,7 +996,7 @@ async def ainvoke(input, config): @pytest.mark.asyncio async def test_execute_agent_step_recursion_limit_env( - monkeypatch, mock_state_with_steps, mock_agent + monkeypatch, mock_state_with_steps, mock_agent, mock_config_thread ): # Should respect AGENT_RECURSION_LIMIT env variable if set and valid monkeypatch.setenv("AGENT_RECURSION_LIMIT", "42") @@ -989,14 +1009,16 @@ async def test_execute_agent_step_recursion_limit_env( ), ), ): - result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder") + result = await _execute_agent_step( + mock_state_with_steps, mock_config_thread, mock_agent, "coder" + ) assert isinstance(result, Command) mock_logger.info.assert_any_call("Recursion limit set to: 42") @pytest.mark.asyncio async def test_execute_agent_step_recursion_limit_env_invalid( - monkeypatch, mock_state_with_steps, mock_agent + monkeypatch, mock_state_with_steps, mock_agent, mock_config_thread ): # Should fallback to default if env variable is invalid monkeypatch.setenv("AGENT_RECURSION_LIMIT", "notanint") @@ -1009,7 +1031,9 @@ async def test_execute_agent_step_recursion_limit_env_invalid( ), ), ): - result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder") + result = await _execute_agent_step( + mock_state_with_steps, mock_config_thread, mock_agent, "coder" + ) assert isinstance(result, Command) mock_logger.warning.assert_any_call( "Invalid AGENT_RECURSION_LIMIT value: 'notanint'. Using default value 25." 
    )
@@ -1018,7 +1042,7 @@ async def test_execute_agent_step_recursion_limit_env_invalid(
 
 @pytest.mark.asyncio
 async def test_execute_agent_step_recursion_limit_env_negative(
-    monkeypatch, mock_state_with_steps, mock_agent
+    monkeypatch, mock_state_with_steps, mock_agent, mock_config_thread
 ):
     # Should fallback to default if env variable is negative or zero
     monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")
@@ -1031,7 +1055,9 @@ async def test_execute_agent_step_recursion_limit_env_negative(
             ),
         ),
     ):
-        result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
+        result = await _execute_agent_step(
+            mock_state_with_steps, mock_config_thread, mock_agent, "coder"
+        )
     assert isinstance(result, Command)
     mock_logger.warning.assert_any_call(
         "AGENT_RECURSION_LIMIT value '-5' (parsed as -5) is not positive. Using default value 25."
@@ -1091,7 +1117,7 @@ def patch_create_agent():
 
 @pytest.fixture
 def patch_execute_agent_step():
-    async def fake_execute_agent_step(state, agent, agent_type):
+    async def fake_execute_agent_step(state, config, agent, agent_type):
         return "EXECUTED"
 
     with patch(
diff --git a/web/src/app/chat/components/input-box.tsx b/web/src/app/chat/components/input-box.tsx
index ae18ec495..760cb366e 100644
--- a/web/src/app/chat/components/input-box.tsx
+++ b/web/src/app/chat/components/input-box.tsx
@@ -7,6 +7,7 @@ import { ArrowUp, Lightbulb, X } from "lucide-react";
 import { useTranslations } from "next-intl";
 import { useCallback, useRef, useState } from "react";
+
 import { Detective } from "~/components/deer-flow/icons/detective";
 import MessageInput, {
   type MessageInputRef,
diff --git a/web/src/app/chat/components/messages-block.tsx b/web/src/app/chat/components/messages-block.tsx
index 31c125f84..fcee9d779 100644
--- a/web/src/app/chat/components/messages-block.tsx
+++ b/web/src/app/chat/components/messages-block.tsx
@@ -2,7 +2,7 @@
 // SPDX-License-Identifier: MIT
 
 import { motion } from "framer-motion";
-import { FastForward, Play } from "lucide-react";
+import { FastForward, Play, CornerDownLeft } from "lucide-react";
 import { useTranslations } from "next-intl";
 import { useCallback, useRef, useState } from "react";
@@ -59,7 +59,7 @@ export function MessagesBlock({ className }: { className?: string }) {
           abortSignal: abortController.signal,
         },
       );
-    } catch {}
+    } catch { }
     },
     [feedback],
   );
@@ -186,6 +186,17 @@ export function MessagesBlock({ className }: { className?: string }) {
                   {t("play")}
                 </>
               )}
             </Button>
+            {!responding && replayStarted && (
+              <Button variant="outline">
+                <CornerDownLeft size={16} />
+              </Button>
+            )}
           )}
diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx
index 953666939..61d800f2d 100644
--- a/web/src/app/chat/page.tsx
+++ b/web/src/app/chat/page.tsx
@@ -14,8 +14,10 @@ import { Button } from "~/components/ui/button";
 
 import { Logo } from "../../components/deer-flow/logo";
 import { ThemeToggle } from "../../components/deer-flow/theme-toggle";
 import { Tooltip } from "../../components/deer-flow/tooltip";
+import { ConversationsDialog } from "../settings/dialogs/conversations-dialog";
 import { SettingsDialog } from "../settings/dialogs/settings-dialog";
 
+
 const Main = dynamic(() => import("./main"), {
   ssr: false,
   loading: () => (
@@ -44,6 +46,9 @@ export default function HomePage() {
+          <Tooltip title="Conversations">
+            <ConversationsDialog />
+          </Tooltip>
"@ant-design/icons"; +import { MessageSquareReply, Play, FileText, Newspaper, Users, GraduationCap, CircleCheck, CircleX, CircleAlert, Ellipsis } from "lucide-react"; +import { useState } from "react"; + +import { LoadingAnimation } from "~/components/deer-flow/loading-animation"; +import { RainbowText } from "~/components/deer-flow/rainbow-text"; +import { Tooltip } from "~/components/deer-flow/tooltip"; +import { Button } from "~/components/ui/button"; +import { + Card, + CardDescription, + CardHeader, + CardTitle, +} from "~/components/ui/card"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "~/components/ui/dialog"; +import { useConversations } from "~/core/api/hooks"; +import { cn } from "~/lib/utils"; + +export function ConversationsDialog() { + const [open, setOpen] = useState(false); + // Fetch conversations when dialog opens + const { results, loading } = useConversations(); + + const handleOpenChange = (isOpen: boolean) => { + setOpen(isOpen); + } + + const conversations = [ + ...(results ?? []), + // Placeholder for replays data + { id: "ai-twin-insurance", data_type: 'txt', title: "Write an article on \"Would you insure your AI twin?\"", date: "2025/5/19 12:54", category: "Social Media", count: 500 }, + { id: "china-food-delivery", data_type: 'txt', title: "如何看待外卖大战", date: "2025/5/20 14:30", category: "Research", count: 1000 }, + { id: "eiffel-tower-vs-tallest-building", data_type: 'txt', title: "How many times taller is the Eiffel Tower than the tallest building in the world?", date: "2025/5/21 16:45", category: "Technology", count: 8 }, + { id: "github-top-trending-repo", data_type: 'txt', title: "Write a brief on the top 1 trending repo on Github today.", date: "2025/5/22 18:00", category: "Education", count: 120 }, + { id: "nanjing-traditional-dishes", data_type: 'txt', title: "Write an article about Nanjing's traditional dishes.", date: "2025/5/23 20:15", category: "Health", count: 60 }, + { id: "rental-apartment-decoration", data_type: 'txt', title: "How to decorate a small rental apartment?", date: "2025/5/23 20:15", category: "Health", count: 116 }, + { id: "review-of-the-professional", data_type: 'txt', title: "Introduce the movie 'Léon: The Professional'", date: "2025/5/23 20:15", category: "Health", count: 678 }, + { id: "ultra-processed-foods", data_type: 'txt', title: "Are ultra-processed foods linked to health?", date: "2025/5/23 20:15", category: "Health", count: 600 }, + ]; // Placeholder for replays data + + return ( + + + + + + + + + Conversations + + Replay your conversations here. + + +
+ {loading ? ( +
+ +
+ ) : conversations.length === 0 ? ( +
+

No conversations found.

+
+ ) : (<>) + } + {conversations.map((result) => ( +
+ +
+
+ { + result.category === "social_media" ? ( + + ) : result.category === "news" ? ( + + ) : result.category === "academic" ? ( + + ) : result.category === "popular_science" ? ( + + ) : ( + + ) + } +
+
+ + + + + {`${result.title}`} + + + + + {`${result.date.substring(0, 19).replace(/-/g, "/").replace("T", " ")} | ${result.category} | ${result.count} messages`} + + { + result.count === 0 ? ( + + ) : result.count > 800 ? ( + + ) : result.count < 800 && result.count > 100 ? ( + + ) : ( + + ) + } + + +
+
+ +
+
+
+
+ ))} + +
+ + + + + +
+
+ ); +} diff --git a/web/src/app/settings/tabs/about-tab.tsx b/web/src/app/settings/tabs/about-tab.tsx index 14c508892..708ff90c6 100644 --- a/web/src/app/settings/tabs/about-tab.tsx +++ b/web/src/app/settings/tabs/about-tab.tsx @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT import { BadgeInfo } from "lucide-react"; -import { useLocale, useTranslations } from "next-intl"; +import { useLocale } from "next-intl"; import { Markdown } from "~/components/deer-flow/markdown"; diff --git a/web/src/core/api/chat.ts b/web/src/core/api/chat.ts index a992d4642..067fe3358 100644 --- a/web/src/core/api/chat.ts +++ b/web/src/core/api/chat.ts @@ -5,10 +5,11 @@ import { env } from "~/env"; import type { MCPServerMetadata } from "../mcp"; import type { Resource } from "../messages"; -import { extractReplayIdFromSearchParams } from "../replay/get-replay-id"; +import { extractFromSearchParams } from "../replay/get-replay-id"; import { fetchStream } from "../sse"; import { sleep } from "../utils"; +import { queryConversationByPath } from "./conversations"; import { resolveServiceURL } from "./resolve-service-url"; import type { ChatEvent } from "./types"; @@ -40,11 +41,12 @@ export async function* chatStream( if ( env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY || location.search.includes("mock") || - location.search.includes("replay=") - ) + location.search.includes("replay=") || + location.search.includes("thread_id=") + ) return yield* chatReplayStream(userMessage, params, options); - - try{ + + try { const stream = fetchStream(resolveServiceURL("chat/stream"), { body: JSON.stringify({ messages: [{ role: "user", content: userMessage }], @@ -52,14 +54,14 @@ export async function* chatStream( }), signal: options.abortSignal, }); - + for await (const event of stream) { yield { type: event.event, data: JSON.parse(event.data), } as ChatEvent; } - }catch(e){ + } catch (e) { console.error(e); } } @@ -74,13 +76,13 @@ async function* chatReplayStream( max_search_results?: number; interrupt_feedback?: string; } = { - thread_id: "__mock__", - auto_accepted_plan: false, - max_plan_iterations: 3, - max_step_num: 1, - max_search_results: 3, - interrupt_feedback: undefined, - }, + thread_id: "__mock__", + auto_accepted_plan: false, + max_plan_iterations: 3, + max_step_num: 1, + max_search_results: 3, + interrupt_feedback: undefined, + }, options: { abortSignal?: AbortSignal } = {}, ): AsyncIterable { const urlParams = new URLSearchParams(window.location.search); @@ -98,8 +100,17 @@ async function* chatReplayStream( } } fastForwardReplaying = true; + } else if (urlParams.has("thread_id")) { + const threadId = extractFromSearchParams(window.location.search, "thread_id"); + if (threadId) { + replayFilePath = `/api/conversation/${threadId}`; + } else { + // Fallback to a default replay + replayFilePath = `/replay/eiffel-tower-vs-tallest-building.txt`; + } + fastForwardReplaying = true; } else { - const replayId = extractReplayIdFromSearchParams(window.location.search); + const replayId = extractFromSearchParams(window.location.search, "replay"); if (replayId) { replayFilePath = `/replay/${replayId}.txt`; } else { @@ -107,7 +118,9 @@ async function* chatReplayStream( replayFilePath = `/replay/eiffel-tower-vs-tallest-building.txt`; } } - const text = await fetchReplay(replayFilePath, { + const text = replayFilePath.startsWith("/api/conversation") ? 
await queryConversationByPath(replayFilePath, {
+        abortSignal: options.abortSignal,
+      })
+    : await fetchReplay(replayFilePath, {
     abortSignal: options.abortSignal,
   });
   const normalizedText = text.replace(/\r\n/g, "\n");
diff --git a/web/src/core/api/conversations.ts b/web/src/core/api/conversations.ts
new file mode 100644
index 000000000..a38248423
--- /dev/null
+++ b/web/src/core/api/conversations.ts
@@ -0,0 +1,58 @@
+import type { Conversation } from "../messages";
+
+import { resolveServiceURL } from "./resolve-service-url";
+
+export async function queryConversations() {
+  const response = await fetch(resolveServiceURL(`conversations?limit=100&sort=ts&offset=0`), {
+    method: "GET",
+    headers: {
+      "Content-Type": "application/json",
+    },
+  })
+    .then((res) => res.json())
+    .then((res) => {
+      return res.data ? (res.data as Array<Conversation>) : [];
+    })
+    .catch(() => {
+      return [];
+    });
+  return response;
+}
+
+export async function queryConversationById(thread_id: string) {
+  const response = await fetch(resolveServiceURL(`conversation/${thread_id}`), {
+    method: "GET",
+    headers: {
+      "Content-Type": "text/plain; charset=UTF-8",
+    },
+  })
+    .then((res) => res.text())
+    .then((res) => {
+      return res;
+    })
+    .catch(() => {
+      return "";
+    });
+  return response;
+}
+
+
+export async function queryConversationByPath(path: string, options: { abortSignal?: AbortSignal } = {}) {
+  // `path` looks like "/api/conversation/<thread_id>"; strip the leading
+  // "/api/" because resolveServiceURL prepends the API base itself.
+  const response = await fetch(resolveServiceURL(`${path.substring(5)}`), {
+    method: "GET",
+    headers: {
+      "Content-Type": "text/plain; charset=UTF-8",
+    },
+    signal: options.abortSignal,
+  })
+    .then((res) => res.text())
+    .then((res) => {
+      return res;
+    })
+    .catch(() => {
+      return `Failed to fetch conversation by path: ${path}`;
+    });
+
+  return response;
+}
\ No newline at end of file
diff --git a/web/src/core/api/hooks.ts b/web/src/core/api/hooks.ts
index 133cc6259..789bec34c 100644
--- a/web/src/core/api/hooks.ts
+++ b/web/src/core/api/hooks.ts
@@ -6,9 +6,11 @@ import { useEffect, useRef, useState } from "react";
 import { env } from "~/env";
 
 import type { DeerFlowConfig } from "../config";
+import type { Conversation } from "../messages";
 import { useReplay } from "../replay";
 
 import { fetchReplayTitle } from "./chat";
+import { queryConversations } from "./conversations";
 import { resolveServiceURL } from "./resolve-service-url";
 
 export function useReplayMetadata() {
@@ -71,3 +73,47 @@ export function useConfig(): {
 
   return { config, loading };
 }
+
+export function useConversations(): {
+  results: Conversation[] | null;
+  loading: boolean;
+} {
+  const [results, setResults] = useState<Array<Conversation>>([]);
+  const [loading, setLoading] = useState(true);
+  const hasInitialized = useRef(false);
+  const maxRetries = useRef(3);
+
+  useEffect(() => {
+    if (env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY) {
+      setLoading(false);
+      return;
+    }
+    // Prevent multiple calls
+    if (hasInitialized.current || maxRetries.current <= 0) {
+      return;
+    }
+
+    queryConversations()
+      .then((data) => {
+        setResults(data);
+        setLoading(false);
+        hasInitialized.current = true;
+        maxRetries.current = 0; // No further attempts needed after a successful fetch
+      })
+      .catch((error) => {
+        console.error("Failed to fetch conversations", error);
+        setLoading(false);
+        if (maxRetries.current > 0) {
+          maxRetries.current -= 1;
+          console.warn(`Retrying... 
(${3 - maxRetries.current} attempts left)`); + } + }); + + return () => { + hasInitialized.current = false; + maxRetries.current = 3; // Reset retries on unmount + }; + }, []); + + return { results, loading }; +} \ No newline at end of file diff --git a/web/src/core/messages/types.ts b/web/src/core/messages/types.ts index c4dd9ff94..efb05ccfb 100644 --- a/web/src/core/messages/types.ts +++ b/web/src/core/messages/types.ts @@ -7,12 +7,12 @@ export interface Message { id: string; threadId: string; agent?: - | "coordinator" - | "planner" - | "researcher" - | "coder" - | "reporter" - | "podcast"; + | "coordinator" + | "planner" + | "researcher" + | "coder" + | "reporter" + | "podcast"; role: MessageRole; isStreaming?: boolean; content: string; @@ -43,3 +43,11 @@ export interface Resource { uri: string; title: string; } +export interface Conversation { + id: string; + title: string; + count: number; + date: string; + category: string; + data_type: string; +} diff --git a/web/src/core/replay/get-replay-id.ts b/web/src/core/replay/get-replay-id.ts index 4b124bcd6..0400acf45 100644 --- a/web/src/core/replay/get-replay-id.ts +++ b/web/src/core/replay/get-replay-id.ts @@ -1,10 +1,10 @@ // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates // SPDX-License-Identifier: MIT -export function extractReplayIdFromSearchParams(params: string) { +export function extractFromSearchParams(params: string, name = "replay") { const urlParams = new URLSearchParams(params); - if (urlParams.has("replay")) { - return urlParams.get("replay"); + if (urlParams.has(name)) { + return urlParams.get(name); } return null; -} +} \ No newline at end of file diff --git a/web/src/core/replay/hooks.ts b/web/src/core/replay/hooks.ts index 27f1f4c2b..1f35ee876 100644 --- a/web/src/core/replay/hooks.ts +++ b/web/src/core/replay/hooks.ts @@ -6,12 +6,12 @@ import { useMemo } from "react"; import { env } from "~/env"; -import { extractReplayIdFromSearchParams } from "./get-replay-id"; +import { extractFromSearchParams } from "./get-replay-id"; export function useReplay() { const searchParams = useSearchParams(); const replayId = useMemo( - () => extractReplayIdFromSearchParams(searchParams.toString()), + () => extractFromSearchParams(searchParams.toString(), "replay") ?? extractFromSearchParams(searchParams.toString(), "thread_id"), [searchParams], ); return {
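For reference, the text served by `/api/conversation/{thread_id}` is the raw SSE stream that `chatReplayStream` walks through on the frontend. A hedged, illustrative parser for that format (blank-line-separated chunks with `event:` and `data:` fields), useful when inspecting a stored replay offline:

```python
import json

def iter_replay_events(text: str):
    """Yield (event, data) pairs from a stored SSE replay."""
    # Normalize line endings the same way the web client does.
    for chunk in text.replace("\r\n", "\n").split("\n\n"):
        event, data = "message", None
        for line in chunk.splitlines():
            if line.startswith("event:"):
                event = line[len("event:"):].strip()
            elif line.startswith("data:"):
                data = json.loads(line[len("data:"):].strip())
        if data is not None:
            yield event, data

# Example: tally message chunks per event type in a downloaded replay.
# with open("replay.txt") as f:
#     for event, data in iter_replay_events(f.read()):
#         print(event, data.get("finish_reason"))
```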