from typing import Dict, List, Literal, Optional

import yaml
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import (
+    BaseSettings,
+    PydanticBaseSettingsSource,
+    SettingsConfigDict,
+)


-class AuthConfig(BaseModel):
+class AuthConfig(BaseSettings):
    """Authentication configuration for the Spark server."""

-    username: str = Field(None, alias="username")
-    password: str = Field(None, alias="password")
-    token: str = Field(None, alias="token")
+    username: Optional[str] = Field(None)
+    password: Optional[str] = Field(None)
+    token: Optional[str] = Field(None)

-    def __init__(self, **data):
-        # Support environment variables for sensitive data
-        if not data.get("username"):
-            data["username"] = os.getenv("SHS_SPARK_USERNAME")
-        if not data.get("password"):
-            data["password"] = os.getenv("SHS_SPARK_PASSWORD")
-        if not data.get("token"):
-            data["token"] = os.getenv("SHS_SPARK_TOKEN")
-        super().__init__(**data)

-
-class ServerConfig(BaseModel):
+class ServerConfig(BaseSettings):
    """Server configuration for the Spark server."""

    url: Optional[str] = None
-    auth: AuthConfig = Field(None, alias="auth")
-    default: bool = Field(False, alias="default")
-    verify_ssl: bool = Field(True, alias="verify_ssl")
+    auth: AuthConfig = Field(default_factory=AuthConfig, exclude=True)
+    default: bool = False
+    verify_ssl: bool = True
    emr_cluster_arn: Optional[str] = None  # EMR specific field


-class McpConfig(BaseModel):
+class McpConfig(BaseSettings):
    """Configuration for the MCP server."""

    transports: List[Literal["stdio", "sse", "streamable-http"]] = Field(
        default_factory=list
    )
-    address: str = Field(default="localhost")
-    port: str = Field(default="18888")
-    debug: bool = Field(default=False)
+    address: Optional[str] = "localhost"
+    port: Optional[int | str] = "18888"
+    debug: Optional[bool] = False
+    model_config = SettingsConfigDict(extra="ignore")


-class Config(BaseModel):
+class Config(BaseSettings):
    """Configuration for the Spark client."""

    servers: Dict[str, ServerConfig]
    mcp: Optional[McpConfig] = None
+    model_config = SettingsConfigDict(
+        env_prefix="SHS_",
+        env_nested_delimiter="_",
+        env_file=".env",
+        env_file_encoding="utf-8",
+    )

    @classmethod
    def from_file(cls, file_path: str) -> "Config":
@@ -60,3 +62,14 @@ def from_file(cls, file_path: str) -> "Config":
            config_data = yaml.safe_load(f)

        return cls.model_validate(config_data)
+
+    @classmethod
+    def settings_customise_sources(
+        cls,
+        settings_cls: type[BaseSettings],
+        init_settings: PydanticBaseSettingsSource,
+        env_settings: PydanticBaseSettingsSource,
+        dotenv_settings: PydanticBaseSettingsSource,
+        file_secret_settings: PydanticBaseSettingsSource,
+    ) -> tuple[PydanticBaseSettingsSource, ...]:
+        return env_settings, dotenv_settings, init_settings, file_secret_settings
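
A minimal usage sketch for reviewers. The file name and its contents are hypothetical and only illustrate the from_file path added above; they are not part of this change.

# config.yaml (hypothetical):
#   servers:
#     local:
#       url: http://localhost:18080
#       default: true
#   mcp:
#     transports: ["stdio"]
#     port: 18888

config = Config.from_file("config.yaml")  # yaml.safe_load, then model_validate
print(config.servers["local"].url)        # -> http://localhost:18080

When Config() is constructed directly, the customised source order puts env_settings first, so SHS_-prefixed environment variables (and the .env file) are presumably intended to take precedence over values passed as keyword arguments.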