5 changes: 4 additions & 1 deletion β€” README.md

@@ -112,6 +112,8 @@ servers:
     auth: # optional
       username: "user"
       password: "pass"
+    disabled_tools: [] # optional, list of tools to disable for this server
+disabled_tools: [] # optional, global list of tools to disable for all servers
 mcp:
   transports:
     - streamable-http # streamable-http or stdio.
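For context, a sketch of where each new `disabled_tools` key sits once the file is parsed β€” the server name `local` and its `url` key are assumptions for illustration, not part of this diff:

```python
# Sketch: parse a config fragment and show where each disabled_tools key lives.
# The server name "local" and the url key are hypothetical.
import yaml  # PyYAML

config_text = """
servers:
  local:
    url: "http://localhost:18080"
    disabled_tools:              # per-server: applies to this server only
      - get_resource_usage_timeline
disabled_tools:                  # global: applies to all servers
  - compare_job_environments
"""

config = yaml.safe_load(config_text)
print(config["servers"]["local"]["disabled_tools"])  # ['get_resource_usage_timeline']
print(config["disabled_tools"])                      # ['compare_job_environments']
```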
@@ -185,7 +187,7 @@ The MCP server provides **18 specialized tools** organized by analysis patterns.
 *SQL performance analysis and execution plan comparison*
 | πŸ”§ Tool | πŸ“ Description |
 |---------|----------------|
-| `list_slowest_sql_queries` | 🐌 Get the top N slowest SQL queries for an application with detailed execution metrics |
+| `list_slowest_sql_queries` | 🐌 Get the top N slowest SQL queries for an application with detailed execution metrics and optional plan descriptions |
 | `compare_sql_execution_plans` | πŸ” Compare SQL execution plans between two Spark jobs, analyzing logical/physical plans and execution metrics |

 ### 🚨 Performance & Bottleneck Analysis
@@ -302,6 +304,7 @@ SHS_SERVERS_*_AUTH_TOKEN - Token for a specific server
 SHS_SERVERS_*_VERIFY_SSL - Whether to verify SSL for a specific server (true/false)
 SHS_SERVERS_*_TIMEOUT - HTTP request timeout in seconds for a specific server (default: 30)
 SHS_SERVERS_*_EMR_CLUSTER_ARN - EMR cluster ARN for a specific server
+SHS_SERVERS_*_INCLUDE_PLAN_DESCRIPTION - Whether to include SQL execution plans by default for a specific server (true/false, default: false)
 ```

 ## πŸ€– AI Agent Integration
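Beyond the `SHS_SERVERS_*` variables above, the `tool_filter.py` module added later in this diff reads two more toggles: a per-tool `SHS_DISABLE_<TOOL_NAME>` switch and a comma-separated `SHS_GLOBAL_DISABLED_TOOLS` list. A runnable sketch of that precedence (the `env_says_disabled` helper is just for illustration):

```python
# Sketch of the env-var checks from src/spark_history_mcp/utils/tool_filter.py.
import os

os.environ["SHS_DISABLE_LIST_APPLICATIONS"] = "true"              # per-tool kill switch
os.environ["SHS_GLOBAL_DISABLED_TOOLS"] = "get_stage, list_jobs"  # comma-separated list


def env_says_disabled(tool_name: str) -> bool:
    # The per-tool variable is checked first ("true"/"1"/"yes" disable the tool)...
    if os.getenv(f"SHS_DISABLE_{tool_name.upper()}", "").lower() in ("true", "1", "yes"):
        return True
    # ...then the comma-separated global list is consulted.
    global_list = [t.strip() for t in os.getenv("SHS_GLOBAL_DISABLED_TOOLS", "").split(",")]
    return tool_name in global_list


print(env_says_disabled("list_applications"))  # True (per-tool variable)
print(env_says_disabled("list_jobs"))          # True (global list)
print(env_says_disabled("get_application"))    # False
```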
1 change: 1 addition & 0 deletions β€” src/spark_history_mcp/config/config.py

@@ -28,6 +28,7 @@ class ServerConfig(BaseSettings):
     emr_cluster_arn: Optional[str] = None  # EMR specific field
     use_proxy: bool = False
     timeout: int = 30  # HTTP request timeout in seconds
+    disabled_tools: List[str] = Field(default_factory=list)  # Tools to disable


 class McpConfig(BaseSettings):
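A quick sketch of the new field in use β€” this assumes `ServerConfig` can be constructed with only these keyword arguments, which holds for the fields visible in this hunk but is unverified for the rest of the class:

```python
# Sketch: exercise the new disabled_tools field on ServerConfig.
# Assumes no other field of ServerConfig is required.
from spark_history_mcp.config.config import ServerConfig

cfg = ServerConfig(disabled_tools=["get_job_bottlenecks"])
print(cfg.disabled_tools)  # ['get_job_bottlenecks']
print(cfg.timeout)         # 30 (default shown in this hunk)

# default_factory=list gives each instance its own empty list
print(ServerConfig().disabled_tools)  # []
```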
37 changes: 19 additions & 18 deletions β€” src/spark_history_mcp/tools/tools.py

@@ -18,6 +18,7 @@
     TaskMetricDistributions,
 )

+from ..utils.tool_filter import conditional_tool
 from ..utils.utils import parallel_execute

 logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ def get_client_or_default(
     )


-@mcp.tool()
+@conditional_tool(mcp)
 def list_applications(
     server: Optional[str] = None,
     status: Optional[list[str]] = None,
@@ -130,7 +131,7 @@ def list_applications(
     return all_apps


-@mcp.tool()
+@conditional_tool(mcp)
 def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInfo:
     """
     Get detailed information about a specific Spark application.
@@ -151,7 +152,7 @@ def get_application(app_id: str, server: Optional[str] = None) -> ApplicationInf
     return client.get_application(app_id)


-@mcp.tool()
+@conditional_tool(mcp)
 def list_jobs(
     app_id: str, server: Optional[str] = None, status: Optional[list[str]] = None
 ) -> list:
@@ -177,7 +178,7 @@ def list_jobs(
     return client.list_jobs(app_id=app_id, status=job_statuses)


-@mcp.tool()
+@conditional_tool(mcp)
 def list_slowest_jobs(
     app_id: str,
     server: Optional[str] = None,
@@ -222,7 +223,7 @@ def get_job_duration(job):
     return heapq.nlargest(n, jobs, key=get_job_duration)


-@mcp.tool()
+@conditional_tool(mcp)
 def list_stages(
     app_id: str,
     server: Optional[str] = None,
@@ -259,7 +260,7 @@ def list_stages(
     )


-@mcp.tool()
+@conditional_tool(mcp)
 def list_slowest_stages(
     app_id: str,
     server: Optional[str] = None,
@@ -302,7 +303,7 @@ def get_stage_duration(stage: StageData):
     return heapq.nlargest(n, stages, key=get_stage_duration)


-@mcp.tool()
+@conditional_tool(mcp)
 def get_stage(
     app_id: str,
     stage_id: int,
@@ -368,7 +369,7 @@ def get_stage(
     return stage_data


-@mcp.tool()
+@conditional_tool(mcp)
 def get_environment(app_id: str, server: Optional[str] = None):
     """
     Get the comprehensive Spark runtime configuration for a Spark application.
@@ -389,7 +390,7 @@ def get_environment(app_id: str, server: Optional[str] = None):
     return client.get_environment(app_id=app_id)


-@mcp.tool()
+@conditional_tool(mcp)
 def list_executors(
     app_id: str, server: Optional[str] = None, include_inactive: bool = False
 ):
@@ -416,7 +417,7 @@ def list_executors(
     return client.list_executors(app_id=app_id)


-@mcp.tool()
+@conditional_tool(mcp)
 def get_executor(app_id: str, executor_id: str, server: Optional[str] = None):
     """
     Get information about a specific executor.
@@ -445,7 +446,7 @@ def get_executor(app_id: str, executor_id: str, server: Optional[str] = None):
     return None


-@mcp.tool()
+@conditional_tool(mcp)
 def get_executor_summary(app_id: str, server: Optional[str] = None):
     """
     Aggregates metrics across all executors for a Spark application.
@@ -467,7 +468,7 @@ def get_executor_summary(app_id: str, server: Optional[str] = None):
     return _calculate_executor_metrics(executors)


-@mcp.tool()
+@conditional_tool(mcp)
 def compare_job_environments(
     app_id1: str, app_id2: str, server: Optional[str] = None
 ) -> Dict[str, Any]:
@@ -582,7 +583,7 @@ def _calc_executor_summary_from_client(client, app_id: str):
     return _calculate_executor_metrics(executors)


-@mcp.tool()
+@conditional_tool(mcp)
 def compare_job_performance(
     app_id1: str, app_id2: str, server: Optional[str] = None
 ) -> Dict[str, Any]:
@@ -736,7 +737,7 @@ def calc_job_stats(jobs):
     return comparison


-@mcp.tool()
+@conditional_tool(mcp)
 def compare_sql_execution_plans(
     app_id1: str,
     app_id2: str,
@@ -857,7 +858,7 @@ def analyze_nodes(execution):
     return comparison


-@mcp.tool()
+@conditional_tool(mcp)
 def get_stage_task_summary(
     app_id: str,
     stage_id: int,
@@ -914,7 +915,7 @@ def truncate_plan_description(plan_desc: str, max_length: int) -> str:
     return truncated + "\n... [truncated]"


-@mcp.tool()
+@conditional_tool(mcp)
 def list_slowest_sql_queries(
     app_id: str,
     server: Optional[str] = None,
@@ -1010,7 +1011,7 @@ def list_slowest_sql_queries(
     return simplified_results


-@mcp.tool()
+@conditional_tool(mcp)
 def get_job_bottlenecks(
     app_id: str, server: Optional[str] = None, top_n: int = 5
 ) -> Dict[str, Any]:
@@ -1152,7 +1153,7 @@ def get_job_bottlenecks(
     return bottlenecks


-@mcp.tool()
+@conditional_tool(mcp)
 def get_resource_usage_timeline(
     app_id: str, server: Optional[str] = None
 ) -> Dict[str, Any]:
79 changes: 79 additions & 0 deletions src/spark_history_mcp/utils/tool_filter.py
@@ -0,0 +1,79 @@
"""Tool filtering utilities for conditional MCP tool registration."""

import logging
import os
from typing import Callable, Optional, TypeVar

from spark_history_mcp.config.config import Config

F = TypeVar("F", bound=Callable)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def is_tool_enabled(tool_name: str, config_path: str = "config.yaml") -> bool:
"""
Check if a tool is enabled based on configuration and environment variables.

Args:
tool_name: Name of the tool to check
config_path: Path to configuration file (default: "config.yaml")

Returns:
bool: True if tool is enabled, False if disabled
"""
# Check environment variable first (highest priority)
env_var = f"SHS_DISABLE_{tool_name.upper()}"
if os.getenv(env_var, "").lower() in ("true", "1", "yes"):
return False

# Check global environment variable for disabled tools
disabled_tools_env = os.getenv("SHS_GLOBAL_DISABLED_TOOLS", "")
if disabled_tools_env:
disabled_tools = [tool.strip() for tool in disabled_tools_env.split(",")]
if tool_name in disabled_tools:
return False

# Check configuration file
try:
config = Config.from_file(config_path)

# Check if any server has this tool disabled
for server_config in config.servers.values():
if tool_name in server_config.disabled_tools:
return False

except Exception as e:
logger.error(f"Error loading configuration and loading disabled tools: {e}")
return True
> **Review comment on lines +34 to +51 (Collaborator):** Possible to use the settings object for this instead?
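One shape the reviewer's suggestion might take β€” a sketch assuming `pydantic-settings` (the config module already builds on `BaseSettings`); the `ToolFilterSettings` class is hypothetical and not part of this PR, and note that pydantic-settings parses list-typed env vars as JSON rather than the comma-separated form used above:

```python
# Hypothetical reply-sketch: replace the raw os.getenv calls with a settings
# object. Not part of this PR.
from typing import List

from pydantic_settings import BaseSettings, SettingsConfigDict


class ToolFilterSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="SHS_")

    # Reads SHS_GLOBAL_DISABLED_TOOLS; pydantic-settings expects JSON for
    # complex types, e.g. SHS_GLOBAL_DISABLED_TOOLS='["get_stage","list_jobs"]'
    global_disabled_tools: List[str] = []


settings = ToolFilterSettings()
print("get_stage" in settings.global_disabled_tools)
```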



def conditional_tool(
    mcp_instance, tool_name: Optional[str] = None, config_path: str = "config.yaml"
):
    """
    Decorator that conditionally registers an MCP tool based on configuration.

    Args:
        mcp_instance: The FastMCP instance to register tools with
        tool_name: Name of the tool (defaults to function name)
        config_path: Path to configuration file

    Returns:
        Decorator function
    """

    def decorator(func: F) -> F:
        actual_tool_name = tool_name or func.__name__

        if is_tool_enabled(actual_tool_name, config_path):
            # Tool is enabled, register it with MCP
            return mcp_instance.tool()(func)
        else:
            # Tool is disabled, return unregistered function
            return func

    return decorator
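A minimal sketch of the decorator's two paths β€” `FakeMCP` is a hypothetical stand-in that mimics only the `mcp.tool()` registration interface used in `tools.py`:

```python
# Demonstration with a stand-in for FastMCP: conditional_tool either registers
# the function (tool enabled) or returns it untouched (tool disabled).
import os

from spark_history_mcp.utils.tool_filter import conditional_tool


class FakeMCP:
    def __init__(self):
        self.registered = []

    def tool(self):
        def register(func):
            self.registered.append(func.__name__)
            return func
        return register


mcp = FakeMCP()
os.environ["SHS_DISABLE_MY_TOOL"] = "true"


@conditional_tool(mcp)
def my_tool():
    """Never registered: SHS_DISABLE_MY_TOOL short-circuits is_tool_enabled."""


@conditional_tool(mcp)
def other_tool():
    """Registered normally; a missing config.yaml only logs an error and
    defaults the tool to enabled."""


print(mcp.registered)  # ['other_tool']
```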