22#
33# SPDX-License-Identifier: MPL-2.0
44
5- from collections .abc import Callable
6- import asyncio
75import os
6+ import asyncio
87import threading
8+ from contextlib import asynccontextmanager
9+ from typing import Any
10+ from collections .abc import Callable
11+
912import uvicorn
1013from fastapi import FastAPI
1114from fastapi .responses import FileResponse
1215from fastapi_socketio import SocketManager
16+
1317from arduino .app_utils import brick , Logger
1418
1519logger = Logger ("WebUI" )
@@ -32,7 +36,8 @@ def __init__(
3236 api_path_prefix : str = "" ,
3337 assets_dir_path : str = "/app/assets" ,
3438 certs_dir_path : str = "/app/certs" ,
35- use_ssl : bool = False ,
39+ use_tls : bool = False ,
40+ use_ssl : bool | None = None , # Deprecated alias for use_tls
3641 ):
3742 """Initialize the web server.
3843
@@ -42,35 +47,72 @@ def __init__(
4247 ui_path_prefix (str, optional): URL prefix for UI routes. Defaults to "" (root).
4348 api_path_prefix (str, optional): URL prefix for API routes. Defaults to "" (root).
4449 assets_dir_path (str, optional): Path to static assets directory. Defaults to "/app/assets".
45- certs_dir_path (str, optional): Path to SSL certificates directory. Defaults to "/app/certs".
46- use_ssl (bool, optional): Enable SSL/HTTPS. Defaults to False.
50+ certs_dir_path (str, optional): Path to TLS certificates directory. Defaults to "/app/certs".
51+ use_tls (bool, optional): Enable TLS/HTTPS. Defaults to False.
52+ use_ssl (bool, optional): Deprecated. Use use_tls instead. Defaults to None.
4753 """
48- self .app = FastAPI (title = __name__ , openapi_url = None , on_startup = [self ._on_startup ])
54+ # Handle deprecated use_ssl parameter
55+ if use_ssl is not None :
56+ logger .warning ("'use_ssl' parameter is deprecated. Use 'use_tls' instead." )
57+ use_tls = use_ssl
58+
59+ @asynccontextmanager
60+ async def lifespan (app ):
61+ await self ._on_startup ()
62+ yield
63+
64+ self .app = FastAPI (title = __name__ , openapi_url = None , lifespan = lifespan )
4965 self .sio = SocketManager (app = self .app , mount_location = "/socket.io" , socketio_path = "" , max_http_buffer_size = 10 * 1024 * 1024 )
5066
5167 self ._addr = addr
52- self ._port = port
68+
69+ def pick_free_port ():
70+ import socket
71+
72+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
73+ s .bind (("127.0.0.1" , 0 ))
74+ return s .getsockname ()[1 ]
75+
76+ self ._port = port if port != 0 else pick_free_port ()
5377 self ._ui_path_prefix = ui_path_prefix
5478 self ._api_path_prefix = api_path_prefix
5579 self ._assets_dir_path = os .path .abspath (assets_dir_path )
5680 self ._certs_dir_path = os .path .abspath (certs_dir_path )
57- self ._use_ssl = use_ssl
58- self ._protocol = "https" if self ._use_ssl else "http"
59- self ._server : uvicorn .Server = None
81+ self ._use_tls = use_tls
82+ self ._protocol = "https" if self ._use_tls else "http"
83+ self ._server : uvicorn .Server | None = None
6084 self ._server_loop : asyncio .AbstractEventLoop | None = None
61- self ._on_connect_cb : Callable [[str ], None ] = None
62- self ._on_disconnect_cb : Callable [[str ], None ] = None
85+ self ._on_connect_cb : Callable [[str ], None ] | None = None
86+ self ._on_disconnect_cb : Callable [[str ], None ] | None = None
6387 self ._on_message_cbs = {}
6488 self ._on_message_cbs_lock = threading .Lock ()
6589
90+ @property
91+ def local_url (self ) -> str :
92+ """Get the locally addressable URL of the web server.
93+
94+ Returns:
95+ str: The server's URL (including protocol, address, and port).
96+ """
97+ return f"{ self ._protocol } ://localhost:{ self ._port } "
98+
99+ @property
100+ def url (self ) -> str :
101+ """Get the externally addressable URL of the web server.
102+
103+ Returns:
104+ str: The server's URL (including protocol, address, and port).
105+ """
106+ return f"{ self ._protocol } ://{ os .getenv ('HOST_IP' ) or self ._addr } :{ self ._port } "
107+
66108 def start (self ):
67109 """Start the web server asynchronously.
68110
69- This sets up static file routing and WebSocket event handlers, configures SSL if enabled, and launches the server using Uvicorn.
111+ This sets up static file routing and WebSocket event handlers, configures TLS if enabled, and launches the server using Uvicorn.
70112
71113 Raises:
72114 RuntimeError: If 'index.html' is missing in the static assets directory.
73- RuntimeError: If SSL is enabled but certificates are missing or fail to generate.
115+ RuntimeError: If TLS is enabled but certificates fail to generate.
74116 RuntimeWarning: If the server is already running.
75117 """
76118 # Setup static routes and SocketIO events
@@ -82,18 +124,15 @@ def start(self):
82124 self ._init_socketio ()
83125
84126 config = uvicorn .Config (self .app , host = self ._addr , port = self ._port , log_level = "warning" )
85- if self ._use_ssl :
86- from . import certs
127+ if self ._use_tls :
128+ from arduino . app_utils . tls_cert_manager import TLSCertificateManager
87129
88- if not certs .cert_exists (self ._certs_dir_path ):
89- try :
90- certs .generate_self_signed_cert (self ._certs_dir_path )
91- except Exception as e :
92- logger .exception (f"Failed to generate SSL certificate: { e } " )
93- raise RuntimeError ("Failed to generate SSL certificate. Please check the certs directory." ) from e
94-
95- config .ssl_keyfile = certs .get_pkey (self ._certs_dir_path )
96- config .ssl_certfile = certs .get_cert (self ._certs_dir_path )
130+ try :
131+ cert_path , key_path = TLSCertificateManager .get_or_create_certificates (certs_dir = self ._certs_dir_path , common_name = self ._addr )
132+ config .ssl_certfile = cert_path
133+ config .ssl_keyfile = key_path
134+ except Exception as e :
135+ raise RuntimeError ("Failed to configure TLS certificate. Please check the certs directory." ) from e
97136
98137 self ._server = uvicorn .Server (config )
99138
@@ -108,33 +147,31 @@ def stop(self):
108147
109148 def execute (self ):
110149 logger .debug (f"Serving static web files from { self ._assets_dir_path } " )
111- if self ._use_ssl :
112- logger .debug (f"Serving certificates from { self ._certs_dir_path } " )
150+ if self ._use_tls :
151+ logger .debug (f"Using TLS certificates from { self ._certs_dir_path } " )
113152
114153 logger .debug ("Starting server..." )
115154
116155 startup_log = "The application interface is available here:\n "
117- startup_log += f" - Local URL: { self ._protocol } ://localhost:{ self ._port } "
118- host_ip = os .getenv ("HOST_IP" )
119- if host_ip :
120- network_url = f"{ self ._protocol } ://{ host_ip } :{ self ._port } "
121- startup_log += f"\n - Network URL: { network_url } "
156+ startup_log += f" - Local URL: { self .local_url } "
157+ if os .getenv ("HOST_IP" ):
158+ startup_log += f"\n - Network URL: { self .url } "
122159 logger .info (startup_log )
123160
124161 try :
125162 self ._server .run ()
126163 except Exception as e :
127164 logger .exception (f"Error running server: { e } " )
128165
129- def expose_api (self , method : str , path : str , function : callable ):
166+ def expose_api (self , method : str , path : str , function : Callable ):
130167 """Register a route with the specified HTTP method and path.
131168
132169 The path will be prefixed with the api_path_prefix configured during initialization.
133170
134171 Args:
135172 method (str): HTTP method to use (e.g., "GET", "POST").
136173 path (str): URL path for the API endpoint (without the prefix).
137- function (callable ): Function to execute when the route is accessed.
174+ function (Callable ): Function to execute when the route is accessed.
138175 """
139176 self .app .add_api_route (self ._api_path_prefix + path , function , methods = [method ])
140177
@@ -160,7 +197,7 @@ def on_disconnect(self, callback: Callable[[str], None]):
160197 """
161198 self ._on_disconnect_cb = callback
162199
163- def on_message (self , message_type : str , callback : Callable [[str , any ], any ]):
200+ def on_message (self , message_type : str , callback : Callable [[str , Any ], Any ]):
164201 """Register a callback function for a specific WebSocket message type received by clients.
165202
166203 The client should send messages named as message_type for this callback to be triggered.
@@ -170,7 +207,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]):
170207
171208 Args:
172209 message_type (str): The message type name to listen for.
173- callback (Callable[[str, any ], any ]): Function to handle the message. Receives two arguments:
210+ callback (Callable[[str, Any ], Any ]): Function to handle the message. Receives two arguments:
174211 the session ID (sid) and the incoming message data.
175212
176213 """
@@ -180,7 +217,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]):
180217 self ._on_message_cbs [message_type ] = callback
181218 logger .debug (f"Registered listener for message '{ message_type } '" )
182219
183- def send_message (self , message_type : str , message : dict | str , room : str = None ):
220+ def send_message (self , message_type : str , message : dict | str , room : str | None = None ):
184221 """Send a message to connected WebSocket clients.
185222
186223 Args:
@@ -200,7 +237,8 @@ def send_message(self, message_type: str, message: dict | str, room: str = None)
200237 logger .exception (f"Failed to send WebSocket message '{ message_type } ': { e } " )
201238
202239 async def _on_startup (self ):
203- """This function is called by uvicorn when the server starts up, it is necessary to capture the running
240+ """
241+ This function is called by uvicorn when the server starts up, it is necessary to capture the running
204242 asyncio event loop and reuse it later for emitting socket.io events as it requires an asyncio context.
205243 """
206244 self ._server_loop = asyncio .get_running_loop ()
0 commit comments