diff --git a/README.md b/README.md
index 4b11591..9c12dd1 100644
--- a/README.md
+++ b/README.md
@@ -112,6 +112,7 @@ servers:
     auth:  # optional
       username: "user"
       password: "pass"
+    disabled_tools: [] # optional, list of tools to disable for this server
 mcp:
   transports:
     - streamable-http # streamable-http or stdio.
@@ -185,7 +186,7 @@ The MCP server provides **18 specialized tools** organized by analysis patterns.
 *SQL performance analysis and execution plan comparison*
 | 🔧 Tool | 📝 Description |
 |---------|----------------|
-| `list_slowest_sql_queries` | 🐌 Get the top N slowest SQL queries for an application with detailed execution metrics |
+| `list_slowest_sql_queries` | 🐌 Get the top N slowest SQL queries for an application with detailed execution metrics and optional plan descriptions |
 | `compare_sql_execution_plans` | 🔍 Compare SQL execution plans between two Spark jobs, analyzing logical/physical plans and execution metrics |
 
 ### 🚨 Performance & Bottleneck Analysis
@@ -302,6 +303,10 @@ SHS_SERVERS_*_AUTH_TOKEN - Token for a specific server
 SHS_SERVERS_*_VERIFY_SSL - Whether to verify SSL for a specific server (true/false)
 SHS_SERVERS_*_TIMEOUT - HTTP request timeout in seconds for a specific server (default: 30)
 SHS_SERVERS_*_EMR_CLUSTER_ARN - EMR cluster ARN for a specific server
+SHS_SERVERS_*_INCLUDE_PLAN_DESCRIPTION - Whether to include SQL execution plans by default for a specific server (true/false, default: false)
+SHS_SERVERS_*_DISABLED_TOOLS - Comma-separated list of tools to disable for a specific server
+SHS_GLOBAL_DISABLED_TOOLS - Comma-separated list of tool names to disable for all servers
+SHS_DISABLE_<TOOL_NAME> - Set to true/1/yes to disable a single tool (e.g. SHS_DISABLE_LIST_APPLICATIONS=true)
 ```
 
 ## 🤖 AI Agent Integration
diff --git a/src/spark_history_mcp/config/config.py b/src/spark_history_mcp/config/config.py
index 639f9e6..37377a5 100644
--- a/src/spark_history_mcp/config/config.py
+++ b/src/spark_history_mcp/config/config.py
@@ -28,6 +28,7 @@ class ServerConfig(BaseSettings):
     emr_cluster_arn: Optional[str] = None  # EMR specific field
     use_proxy: bool = False
     timeout: int = 30  # HTTP request timeout in seconds
+    disabled_tools: List[str] = Field(default_factory=list)  # Tools to disable
 
 
 class McpConfig(BaseSettings):
diff --git a/src/spark_history_mcp/tools/tools.py b/src/spark_history_mcp/tools/tools.py
index 94d3dc7..56b04fe 100644
--- a/src/spark_history_mcp/tools/tools.py
+++ b/src/spark_history_mcp/tools/tools.py
@@ -18,6 +18,7 @@
     TaskMetricDistributions,
 )
 
+from ..utils.tool_filter import conditional_tool
 from ..utils.utils import parallel_execute
 
 logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ def get_client_or_default(
     )
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_applications(
     server: Optional[str] = None,
     status: Optional[list[str]] = None,
@@ -130,7 +131,7 @@ def list_applications(
     return all_apps
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInfo:
     """
     Get detailed information about a specific Spark application.
@@ -151,7 +152,7 @@ def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInf
     return client.get_application(app_id)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_jobs(
     app_id: str, server: Optional[str] = None, status: Optional[list[str]] = None
 ) -> list:
@@ -177,7 +178,7 @@ def list_jobs(
     return client.list_jobs(app_id=app_id, status=job_statuses)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_slowest_jobs(
     app_id: str,
     server: Optional[str] = None,
@@ -222,7 +223,7 @@ def get_job_duration(job):
     return heapq.nlargest(n, jobs, key=get_job_duration)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_stages(
     app_id: str,
     server: Optional[str] = None,
@@ -259,7 +260,7 @@ def list_stages(
     )
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_slowest_stages(
     app_id: str,
     server: Optional[str] = None,
@@ -302,7 +303,7 @@ def get_stage_duration(stage: StageData):
     return heapq.nlargest(n, stages, key=get_stage_duration)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_stage(
     app_id: str,
     stage_id: int,
@@ -368,7 +369,7 @@ def get_stage(
     return stage_data
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_environment(app_id: str, server: Optional[str] = None):
     """
     Get the comprehensive Spark runtime configuration for a Spark application.
@@ -389,7 +390,7 @@ def get_environment(app_id: str, server: Optional[str] = None):
     return client.get_environment(app_id=app_id)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_executors(
     app_id: str, server: Optional[str] = None, include_inactive: bool = False
 ):
@@ -416,7 +417,7 @@ def list_executors(
     return client.list_executors(app_id=app_id)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_executor(app_id: str, executor_id: str, server: Optional[str] = None):
     """
     Get information about a specific executor.
@@ -445,7 +446,7 @@ def get_executor(app_id: str, executor_id: str, server: Optional[str] = None):
     return None
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_executor_summary(app_id: str, server: Optional[str] = None):
     """
     Aggregates metrics across all executors for a Spark application.
@@ -467,7 +468,7 @@ def get_executor_summary(
     return _calculate_executor_metrics(executors)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def compare_job_environments(
     app_id1: str, app_id2: str, server: Optional[str] = None
 ) -> Dict[str, Any]:
@@ -582,7 +583,7 @@ def _calc_executor_summary_from_client(client, app_id: str):
     return _calculate_executor_metrics(executors)
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def compare_job_performance(
     app_id1: str, app_id2: str, server: Optional[str] = None
 ) -> Dict[str, Any]:
@@ -736,7 +737,7 @@ def calc_job_stats(jobs):
     return comparison
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def compare_sql_execution_plans(
     app_id1: str,
     app_id2: str,
@@ -857,7 +858,7 @@ def analyze_nodes(execution):
     return comparison
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_stage_task_summary(
     app_id: str,
     stage_id: int,
@@ -914,7 +915,7 @@ def truncate_plan_description(plan_desc: str, max_length: int) -> str:
     return truncated + "\n... [truncated]"
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def list_slowest_sql_queries(
     app_id: str,
     server: Optional[str] = None,
@@ -1010,7 +1011,7 @@ def list_slowest_sql_queries(
     return simplified_results
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_job_bottlenecks(
     app_id: str, server: Optional[str] = None, top_n: int = 5
 ) -> Dict[str, Any]:
@@ -1152,7 +1153,7 @@ def get_job_bottlenecks(
     return bottlenecks
 
 
-@mcp.tool()
+@conditional_tool(mcp)
 def get_resource_usage_timeline(
     app_id: str, server: Optional[str] = None
 ) -> Dict[str, Any]:
diff --git a/src/spark_history_mcp/utils/tool_filter.py b/src/spark_history_mcp/utils/tool_filter.py
new file mode 100644
index 0000000..83a5749
--- /dev/null
+++ b/src/spark_history_mcp/utils/tool_filter.py
@@ -0,0 +1,83 @@
+"""Tool filtering utilities for conditional MCP tool registration."""
+
+import logging
+import os
+from typing import Callable, Optional, TypeVar
+
+from spark_history_mcp.config.config import Config
+
+F = TypeVar("F", bound=Callable)
+
+# Library code must not call logging.basicConfig(); the host application
+# owns handler/level configuration. A module-level named logger is enough.
+logger = logging.getLogger(__name__)
+
+
+def is_tool_enabled(tool_name: str, config_path: str = "config.yaml") -> bool:
+    """
+    Check if a tool is enabled based on configuration and environment variables.
+
+    Precedence (highest first):
+      1. SHS_DISABLE_<TOOL_NAME> set to true/1/yes
+      2. SHS_GLOBAL_DISABLED_TOOLS comma-separated list
+      3. Per-server ``disabled_tools`` lists from the config file
+
+    Args:
+        tool_name: Name of the tool to check
+        config_path: Path to configuration file (default: "config.yaml")
+
+    Returns:
+        bool: True if tool is enabled, False if disabled
+    """
+    # Check environment variable first (highest priority)
+    env_var = f"SHS_DISABLE_{tool_name.upper()}"
+    if os.getenv(env_var, "").lower() in ("true", "1", "yes"):
+        return False
+
+    # Check global environment variable for disabled tools
+    disabled_tools_env = os.getenv("SHS_GLOBAL_DISABLED_TOOLS", "")
+    if disabled_tools_env:
+        disabled_tools = [tool.strip() for tool in disabled_tools_env.split(",")]
+        if tool_name in disabled_tools:
+            return False
+
+    # Check configuration file; a missing or unreadable config file must
+    # not disable anything, so errors fall through to "enabled".
+    try:
+        config = Config.from_file(config_path)
+
+        # Check if any server has this tool disabled
+        for server_config in config.servers.values():
+            if tool_name in server_config.disabled_tools:
+                return False
+    except Exception as e:
+        logger.warning("Could not load disabled tools from %s: %s", config_path, e)
+
+    return True
+
+
+def conditional_tool(
+    mcp_instance, tool_name: Optional[str] = None, config_path: str = "config.yaml"
+):
+    """
+    Decorator that conditionally registers an MCP tool based on configuration.
+
+    Args:
+        mcp_instance: The FastMCP instance to register tools with
+        tool_name: Name of the tool (defaults to function name)
+        config_path: Path to configuration file
+
+    Returns:
+        Decorator function
+    """
+
+    def decorator(func: F) -> F:
+        actual_tool_name = tool_name or func.__name__
+
+        if is_tool_enabled(actual_tool_name, config_path):
+            # Tool is enabled, register it with MCP
+            return mcp_instance.tool()(func)
+        # Tool is disabled, return unregistered function
+        return func
+
+    return decorator
diff --git a/tests/unit/test_tool_filter.py b/tests/unit/test_tool_filter.py
new file mode 100644
index 0000000..ffdcbea
--- /dev/null
+++ b/tests/unit/test_tool_filter.py
@@ -0,0 +1,240 @@
+"""Tests for tool filtering functionality."""
+
+import os
+import tempfile
+import unittest
+from unittest.mock import MagicMock, patch
+
+from spark_history_mcp.utils.tool_filter import (
+    conditional_tool,
+    is_tool_enabled,
+)
+
+
+class TestToolFilter(unittest.TestCase):
+    """Test tool filtering functionality."""
+
+    # Tests create SHS_DISABLE_<NAME> variables dynamically, so match by
+    # prefix rather than a fixed list — otherwise a variable set in one
+    # test leaks into the next.
+    def _filter_env_keys(self):
+        return [
+            key
+            for key in os.environ
+            if key.startswith("SHS_DISABLE_")
+            or key in ("SHS_GLOBAL_DISABLED_TOOLS", "SHS_SERVERS_LOCAL_DISABLED_TOOLS")
+        ]
+
+    def setUp(self):
+        """Snapshot and clear env vars that influence tool filtering."""
+        self._saved_env = {key: os.environ.pop(key) for key in self._filter_env_keys()}
+
+    def tearDown(self):
+        """Drop vars set during the test and restore the snapshot."""
+        for key in self._filter_env_keys():
+            del os.environ[key]
+        os.environ.update(self._saved_env)
+
+    def test_is_tool_enabled_default(self):
+        """Test that tools are enabled by default."""
+        # With non-existent config file, tools should be enabled by default
+        self.assertTrue(is_tool_enabled("test_tool", "non_existent_config.yaml"))
+
+    def test_is_tool_enabled_with_global_disabled_tools_env(self):
+        """Test disabling tools via global environment variable."""
+        os.environ["SHS_GLOBAL_DISABLED_TOOLS"] = "tool1,tool2,test_tool"
+
+        self.assertFalse(is_tool_enabled("test_tool"))
+        self.assertFalse(is_tool_enabled("tool1"))
+        self.assertTrue(is_tool_enabled("enabled_tool"))
+
+    def test_is_tool_enabled_with_individual_env_var(self):
+        """Test disabling individual tool via environment variable."""
+        os.environ["SHS_DISABLE_TEST_TOOL"] = "true"
+
+        self.assertFalse(is_tool_enabled("test_tool"))
+        self.assertTrue(is_tool_enabled("other_tool"))
+
+    def test_is_tool_enabled_with_config_file(self):
+        """Test tool filtering via configuration file."""
+        # Create a temporary config file
+        config_data = {
+            "servers": {
+                "local": {
+                    "url": "http://localhost:18080",
+                    "disabled_tools": ["server_disabled"],
+                }
+            },
+        }
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            import yaml
+
+            yaml.dump(config_data, f)
+            config_path = f.name
+
+        try:
+            # Test server-specific disabled tool
+            self.assertFalse(is_tool_enabled("server_disabled", config_path))
+
+            # Test enabled tool
+            self.assertTrue(is_tool_enabled("enabled_tool", config_path))
+        finally:
+            os.unlink(config_path)
+
+    def test_priority_order(self):
+        """Test that environment variables take priority over config file."""
+        # Create config file that enables the tool
+        config_data = {
+            "servers": {
+                "local": {"url": "http://localhost:18080", "disabled_tools": []}
+            },
+        }
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            import yaml
+
+            yaml.dump(config_data, f)
+            config_path = f.name
+
+        try:
+            # Tool should be enabled by config
+            self.assertTrue(is_tool_enabled("test_tool", config_path))
+
+            # Environment variable should override config
+            os.environ["SHS_DISABLE_TEST_TOOL"] = "true"
+            self.assertFalse(is_tool_enabled("test_tool", config_path))
+
+        finally:
+            os.unlink(config_path)
+
+    def test_conditional_tool_decorator_enabled(self):
+        """Test conditional tool decorator when tool is enabled."""
+        mock_mcp = MagicMock()
+        mock_mcp.tool.return_value = lambda func: func
+
+        @conditional_tool(mock_mcp, "enabled_tool")
+        def test_function():
+            return "test"
+
+        # Should be registered with MCP
+        mock_mcp.tool.assert_called_once()
+
+    def test_conditional_tool_decorator_disabled(self):
+        """Test conditional tool decorator when tool is disabled."""
+        os.environ["SHS_DISABLE_DISABLED_TOOL"] = "true"
+
+        mock_mcp = MagicMock()
+        mock_mcp.tool.return_value = lambda func: func
+
+        @conditional_tool(mock_mcp, "disabled_tool")
+        def test_function():
+            return "test"
+
+        # Should NOT be registered with MCP
+        mock_mcp.tool.assert_not_called()
+
+    def test_conditional_tool_decorator_default_name(self):
+        """Test conditional tool decorator using function name."""
+        mock_mcp = MagicMock()
+        mock_mcp.tool.return_value = lambda func: func
+
+        @conditional_tool(mock_mcp)
+        def my_test_function():
+            return "test"
+
+        # Should use function name as tool name
+        mock_mcp.tool.assert_called_once()
+
+    @patch("spark_history_mcp.utils.tool_filter.Config.from_file")
+    def test_config_loading_error_handling(self, mock_from_file):
+        """Test that config loading errors are handled gracefully."""
+        # Make config loading raise an exception
+        mock_from_file.side_effect = Exception("Config loading failed")
+
+        # Should default to enabled when config can't be loaded
+        self.assertTrue(is_tool_enabled("test_tool"))
+
+    def test_environment_variable_parsing(self):
+        """Test parsing of comma-separated environment variables."""
+        os.environ["SHS_GLOBAL_DISABLED_TOOLS"] = " tool1 , tool2 ,tool3"
+
+        self.assertFalse(is_tool_enabled("tool1"))
+        self.assertFalse(is_tool_enabled("tool2"))
+        self.assertFalse(is_tool_enabled("tool3"))
+        self.assertTrue(is_tool_enabled("tool4"))
+
+    def test_case_sensitivity(self):
+        """Test that tool names are case sensitive in config but not in env vars."""
+        # Environment variables use uppercase conversion, so they're not case sensitive
+        os.environ["SHS_DISABLE_MYTEST"] = "true"
+
+        # Both should be disabled because env var converts to uppercase
+        self.assertFalse(is_tool_enabled("mytest"))
+        self.assertFalse(is_tool_enabled("MYTEST"))
+
+        # Test case sensitivity with config file
+        config_data = {
+            "servers": {
+                "local": {
+                    "url": "http://localhost:18080",
+                    "disabled_tools": ["lowercase_only"],
+                }
+            },
+        }
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            import yaml
+
+            yaml.dump(config_data, f)
+            config_path = f.name
+
+        try:
+            # Should be case sensitive in config
+            self.assertFalse(is_tool_enabled("lowercase_only", config_path))
+            self.assertTrue(
+                is_tool_enabled("LOWERCASE_ONLY", config_path)
+            )  # Different case
+        finally:
+            os.unlink(config_path)
+
+    def test_individual_env_var_values(self):
+        """Test different values for individual environment variables."""
+        # Test "true"
+        os.environ["SHS_DISABLE_TEST1"] = "true"
+        self.assertFalse(is_tool_enabled("test1"))
+
+        # Test "1"
+        os.environ["SHS_DISABLE_TEST2"] = "1"
+        self.assertFalse(is_tool_enabled("test2"))
+
+        # Test "yes"
+        os.environ["SHS_DISABLE_TEST3"] = "yes"
+        self.assertFalse(is_tool_enabled("test3"))
+
+        # Test "false" (should not disable)
+        os.environ["SHS_DISABLE_TEST4"] = "false"
+        self.assertTrue(is_tool_enabled("test4"))
+
+        # Test empty string (should not disable)
+        os.environ["SHS_DISABLE_TEST5"] = ""
+        self.assertTrue(is_tool_enabled("test5"))
+
+    def test_whitespace_handling_in_global_env_var(self):
+        """Test that whitespace is properly stripped from global env var."""
+        os.environ["SHS_GLOBAL_DISABLED_TOOLS"] = " tool1 , tool2 , tool3 "
+
+        self.assertFalse(is_tool_enabled("tool1"))
+        self.assertFalse(is_tool_enabled("tool2"))
+        self.assertFalse(is_tool_enabled("tool3"))
+        self.assertTrue(is_tool_enabled("tool4"))
+
+    def test_empty_global_env_var(self):
+        """Test that empty global env var doesn't affect anything."""
+        os.environ["SHS_GLOBAL_DISABLED_TOOLS"] = ""
+
+        self.assertTrue(is_tool_enabled("any_tool"))
+
+
+if __name__ == "__main__":
+    unittest.main()