Skip to content

Commit c19498c

Browse files
authored
Use --config flag everywhere (#127)
Signed-off-by: Damon P. Cortesi <[email protected]>
1 parent a0e727b commit c19498c

File tree

5 files changed

+75
-25
lines changed

5 files changed

+75
-25
lines changed

src/spark_history_mcp/config/config.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,50 @@
11
import os
2-
from typing import Dict, List, Literal, Optional
2+
from typing import Any, Dict, List, Literal, Optional, Tuple
33

44
import yaml
55
from pydantic import Field
6+
from pydantic.fields import FieldInfo
67
from pydantic_settings import (
78
BaseSettings,
89
PydanticBaseSettingsSource,
910
SettingsConfigDict,
1011
)
1112

1213

14+
class YamlConfigSettingsSource(PydanticBaseSettingsSource):
15+
"""Custom settings source that loads configuration from a YAML file.
16+
17+
The file path is determined by the SHS_MCP_CONFIG environment variable,
18+
defaulting to 'config.yaml' if not set.
19+
"""
20+
21+
def get_field_value(
22+
self, field: FieldInfo, field_name: str
23+
) -> Tuple[Any, str, bool]:
24+
# Not used for this implementation
25+
return None, field_name, False
26+
27+
def __call__(self) -> Dict[str, Any]:
28+
"""Load and return the YAML configuration data."""
29+
config_path = os.getenv("SHS_MCP_CONFIG", "config.yaml")
30+
is_explicitly_set = "SHS_MCP_CONFIG" in os.environ
31+
32+
if not os.path.exists(config_path):
33+
# If the config file was explicitly specified but doesn't exist, fail fast
34+
if is_explicitly_set:
35+
raise FileNotFoundError(
36+
f"Config file not found: {config_path}\n"
37+
f"Specified via: SHS_MCP_CONFIG environment variable"
38+
)
39+
# If using default and it doesn't exist, return empty (will use defaults)
40+
return {}
41+
42+
with open(config_path, "r") as f:
43+
config_data = yaml.safe_load(f)
44+
45+
return config_data or {}
46+
47+
1348
class AuthConfig(BaseSettings):
1449
"""Authentication configuration for the Spark server."""
1550

@@ -57,17 +92,6 @@ class Config(BaseSettings):
5792
env_file_encoding="utf-8",
5893
)
5994

60-
@classmethod
61-
def from_file(cls, file_path: str) -> "Config":
62-
"""Load configuration from a YAML file."""
63-
if not os.path.exists(file_path):
64-
return Config()
65-
66-
with open(file_path, "r") as f:
67-
config_data = yaml.safe_load(f)
68-
69-
return cls.model_validate(config_data)
70-
7195
@classmethod
7296
def settings_customise_sources(
7397
cls,
@@ -77,4 +101,16 @@ def settings_customise_sources(
77101
dotenv_settings: PydanticBaseSettingsSource,
78102
file_secret_settings: PydanticBaseSettingsSource,
79103
) -> tuple[PydanticBaseSettingsSource, ...]:
80-
return env_settings, dotenv_settings, init_settings, file_secret_settings
104+
# Precedence order (highest to lowest):
105+
# 1. Environment variables
106+
# 2. .env file
107+
# 3. YAML config file (from SHS_MCP_CONFIG)
108+
# 4. Init settings (constructor arguments)
109+
# 5. File secrets
110+
return (
111+
env_settings,
112+
dotenv_settings,
113+
YamlConfigSettingsSource(settings_cls),
114+
init_settings,
115+
file_secret_settings,
116+
)

src/spark_history_mcp/core/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def default(self, obj):
3333

3434
@asynccontextmanager
3535
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
36-
config = Config.from_file("config.yaml")
36+
# Config() automatically loads from SHS_MCP_CONFIG env var (set in main.py)
37+
config = Config()
3738

3839
clients: dict[str, SparkRestClient] = {}
3940
default_client = None

src/spark_history_mcp/core/main.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,19 @@ def main():
3030
try:
3131
logger.info("Starting Spark History Server MCP...")
3232
logger.info(f"Using config file: {args.config}")
33-
config = Config.from_file(args.config)
33+
34+
# Set the config file path in environment for Pydantic Settings
35+
os.environ["SHS_MCP_CONFIG"] = args.config
36+
37+
# Now Config() will automatically load from the specified YAML file
38+
config = Config()
3439
if config.mcp.debug:
3540
logger.setLevel(logging.DEBUG)
3641
logger.debug(json.dumps(json.loads(config.model_dump_json()), indent=4))
3742
app.run(config)
43+
except FileNotFoundError as e:
44+
logger.error(f"Configuration error: {e}")
45+
sys.exit(1)
3846
except Exception as e:
3947
logger.error(f"Failed to start MCP server: {e}")
4048
sys.exit(1)

tests/emr/test_emr_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def test_spark_client_with_emr_session(self, mock_initialize):
6363
self.assertEqual(apps, [])
6464

6565
@patch("spark_history_mcp.core.app.EMRPersistentUIClient")
66-
@patch("spark_history_mcp.core.app.Config.from_file")
66+
@patch("spark_history_mcp.core.app.Config")
6767
def test_app_lifespan_with_emr_config(
68-
self, mock_config_from_file, mock_emr_client_class
68+
self, mock_config_class, mock_emr_client_class
6969
):
7070
"""Test app_lifespan context manager with EMR configuration."""
7171
import asyncio
@@ -97,7 +97,7 @@ def test_app_lifespan_with_emr_config(
9797
emr_cluster_arn=self.emr_cluster_arn, default=True, verify_ssl=True
9898
)
9999
}
100-
mock_config_from_file.return_value = mock_config
100+
mock_config_class.return_value = mock_config
101101

102102
# Use the app_lifespan context manager
103103
async def test_lifespan():

tests/unit/config.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ def test_config_from_file(self):
3939
temp_file_path = temp_file.name
4040

4141
try:
42-
# Load config from the file
43-
config = Config.from_file(temp_file_path)
42+
# Load config from the file using SHS_MCP_CONFIG env var
43+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
44+
config = Config()
4445

4546
# Verify the loaded configuration
4647
self.assertEqual(config.mcp.address, "test_host")
@@ -63,9 +64,10 @@ def test_config_from_file(self):
6364
os.unlink(temp_file_path)
6465

6566
def test_nonexistent_config_file(self):
66-
"""Test behavior when config file doesn't exist."""
67+
"""Test behavior when explicitly specified config file doesn't exist."""
6768
with self.assertRaises(FileNotFoundError):
68-
Config.from_file("nonexistent_file.yaml")
69+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": "nonexistent_file.yaml"}):
70+
Config()
6971

7072
@patch.dict(
7173
os.environ,
@@ -89,7 +91,8 @@ def test_config_from_env_vars(self):
8991
temp_file_path = temp_file.name
9092

9193
try:
92-
config = Config.from_file(temp_file_path)
94+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
95+
config = Config()
9396

9497
# Verify MCP config from env vars
9598
self.assertEqual(config.mcp.address, "env_host")
@@ -122,7 +125,8 @@ def test_env_vars_override_file_config(self):
122125
temp_file_path = temp_file.name
123126

124127
try:
125-
config = Config.from_file(temp_file_path)
128+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
129+
config = Config()
126130

127131
# Verify that env vars override file config
128132
self.assertEqual(config.mcp.address, "override_host")
@@ -147,7 +151,8 @@ def test_default_values(self):
147151
temp_file_path = temp_file.name
148152

149153
try:
150-
config = Config.from_file(temp_file_path)
154+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
155+
config = Config()
151156

152157
# Check MCP defaults
153158
self.assertEqual(config.mcp.address, "localhost")

0 commit comments

Comments
 (0)