33import warnings
44
55import altair as alt
6+ import numpy as np
7+ import pandas as pd
68import solara
9+ from matplotlib .colors import to_rgb
710
11+ import mesa
812from mesa .discrete_space import DiscreteSpace , Grid
9- from mesa .space import ContinuousSpace , _Grid
13+ from mesa .space import ContinuousSpace , PropertyLayer , _Grid
1014from mesa .visualization .utils import update_counter
1115
1216
@@ -20,13 +24,16 @@ def make_space_altair(*args, **kwargs): # noqa: D103
2024
2125
2226def make_altair_space (
23- agent_portrayal , propertylayer_portrayal , post_process , ** space_drawing_kwargs
27+ agent_portrayal ,
28+ propertylayer_portrayal = None ,
29+ post_process = None ,
30+ ** space_drawing_kwargs ,
2431):
2532 """Create an Altair-based space visualization component.
2633
2734 Args:
2835 agent_portrayal: Function to portray agents.
29- propertylayer_portrayal: not yet implemented
36+ propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
3037 post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks)
3138 space_drawing_kwargs : not yet implemented
3239
@@ -43,14 +50,23 @@ def agent_portrayal(a):
4350 return {"id" : a .unique_id }
4451
4552 def MakeSpaceAltair (model ):
46- return SpaceAltair (model , agent_portrayal , post_process = post_process )
53+ return SpaceAltair (
54+ model ,
55+ agent_portrayal ,
56+ propertylayer_portrayal = propertylayer_portrayal ,
57+ post_process = post_process ,
58+ )
4759
4860 return MakeSpaceAltair
4961
5062
5163@solara .component
5264def SpaceAltair (
53- model , agent_portrayal , dependencies : list [any ] | None = None , post_process = None
65+ model ,
66+ agent_portrayal ,
67+ propertylayer_portrayal = None ,
68+ dependencies : list [any ] | None = None ,
69+ post_process = None ,
5470):
5571 """Create an Altair-based space visualization component.
5672
@@ -63,10 +79,11 @@ def SpaceAltair(
6379 # Sometimes the space is defined as model.space instead of model.grid
6480 space = model .space
6581
66- chart = _draw_grid (space , agent_portrayal )
82+ chart = _draw_grid (space , agent_portrayal , propertylayer_portrayal )
6783 # Apply post-processing if provided
6884 if post_process is not None :
6985 chart = post_process (chart )
86+
7087 solara .FigureAltair (chart )
7188
7289
@@ -138,7 +155,7 @@ def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal):
138155 return all_agent_data
139156
140157
141- def _draw_grid (space , agent_portrayal ):
158+ def _draw_grid (space , agent_portrayal , propertylayer_portrayal ):
142159 match space :
143160 case Grid ():
144161 all_agent_data = _get_agent_data_new_discrete_space (space , agent_portrayal )
@@ -168,23 +185,266 @@ def _draw_grid(space, agent_portrayal):
168185 }
169186 has_color = "color" in all_agent_data [0 ]
170187 if has_color :
171- encoding_dict ["color" ] = alt .Color ("color" , type = "nominal" )
188+ unique_colors = list ({agent ["color" ] for agent in all_agent_data })
189+ encoding_dict ["color" ] = alt .Color (
190+ "color:N" ,
191+ scale = alt .Scale (domain = unique_colors , range = unique_colors ),
192+ )
172193 has_size = "size" in all_agent_data [0 ]
173194 if has_size :
174195 encoding_dict ["size" ] = alt .Size ("size" , type = "quantitative" )
175196
176- chart = (
197+ agent_chart = (
177198 alt .Chart (
178199 alt .Data (values = all_agent_data ), encoding = alt .Encoding (** encoding_dict )
179200 )
180201 .mark_point (filled = True )
181- .properties (width = 280 , height = 280 )
182- # .configure_view(strokeOpacity=0) # hide grid/chart lines
202+ .properties (width = 300 , height = 300 )
183203 )
184- # This is the default value for the marker size, which auto-scales
185- # according to the grid area.
204+ base_chart = None
205+ cbar_chart = None
206+
207+ # This is the default value for the marker size, which auto-scales according to the grid area.
186208 if not has_size :
187209 length = min (space .width , space .height )
188- chart = chart .mark_point (size = 30000 / length ** 2 , filled = True )
210+ agent_chart = agent_chart .mark_point (size = 30000 / length ** 2 , filled = True )
211+
212+ if propertylayer_portrayal is not None :
213+ chart_width = agent_chart .properties ().width
214+ chart_height = agent_chart .properties ().height
215+ base_chart , cbar_chart = chart_property_layers (
216+ space = space ,
217+ propertylayer_portrayal = propertylayer_portrayal ,
218+ chart_width = chart_width ,
219+ chart_height = chart_height ,
220+ )
221+
222+ base_chart = alt .layer (base_chart , agent_chart )
223+ else :
224+ base_chart = agent_chart
225+ if cbar_chart is not None :
226+ base_chart = alt .vconcat (base_chart , cbar_chart ).configure_view (stroke = None )
227+ return base_chart
228+
229+
230+ def chart_property_layers (space , propertylayer_portrayal , chart_width , chart_height ):
231+ """Creates Property Layers in the Altair Components.
232+
233+ Args:
234+ space: the ContinuousSpace instance
235+ propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
236+ chart_width: width of the agent chart to maintain consistency with the property charts
237+ chart_height: height of the agent chart to maintain consistency with the property charts
238+ agent_chart: the agent chart to layer with the property layers on the grid
239+ Returns:
240+ Altair Chart
241+ """
242+ try :
243+ # old style spaces
244+ property_layers = space .properties
245+ except AttributeError :
246+ # new style spaces
247+ property_layers = space ._mesa_property_layers
248+ base = None
249+ bar_chart = None
250+ for layer_name , portrayal in propertylayer_portrayal .items ():
251+ layer = property_layers .get (layer_name , None )
252+ if not isinstance (
253+ layer ,
254+ PropertyLayer | mesa .discrete_space .property_layer .PropertyLayer ,
255+ ):
256+ continue
189257
190- return chart
258+ data = layer .data .astype (float ) if layer .data .dtype == bool else layer .data
259+
260+ if (space .width , space .height ) != data .shape :
261+ warnings .warn (
262+ f"Layer { layer_name } dimensions ({ data .shape } ) do not match space dimensions ({ space .width } , { space .height } )." ,
263+ UserWarning ,
264+ stacklevel = 2 ,
265+ )
266+ alpha = portrayal .get ("alpha" , 1 )
267+ vmin = portrayal .get ("vmin" , np .min (data ))
268+ vmax = portrayal .get ("vmax" , np .max (data ))
269+ colorbar = portrayal .get ("colorbar" , True )
270+
271+ # Prepare data for Altair (convert 2D array to a long-form DataFrame)
272+ df = pd .DataFrame (
273+ {
274+ "x" : np .repeat (np .arange (data .shape [0 ]), data .shape [1 ]),
275+ "y" : np .tile (np .arange (data .shape [1 ]), data .shape [0 ]),
276+ "value" : data .flatten (),
277+ }
278+ )
279+
280+ if "color" in portrayal :
281+ # Create a function to map values to RGBA colors with proper opacity scaling
282+ def apply_rgba (val , vmin = vmin , vmax = vmax , alpha = alpha , portrayal = portrayal ):
283+ """Maps data values to RGBA colors with opacity based on value magnitude.
284+
285+ Args:
286+ val: The data value to convert
287+ vmin: The smallest value for which the color is displayed in the colorbar
288+ vmax: The largest value for which the color is displayed in the colorbar
289+ alpha: The opacity of the color
290+ portrayal: The specifics of the current property layer in the iterative loop
291+
292+ Returns:
293+ String representation of RGBA color
294+ """
295+ # Normalize value to range [0,1] and clamp
296+ normalized = max (0 , min ((val - vmin ) / (vmax - vmin ), 1 ))
297+
298+ # Scale opacity by alpha parameter
299+ opacity = normalized * alpha
300+
301+ # Convert color to RGB components
302+ rgb_color = to_rgb (portrayal ["color" ])
303+ r = int (rgb_color [0 ] * 255 )
304+ g = int (rgb_color [1 ] * 255 )
305+ b = int (rgb_color [2 ] * 255 )
306+
307+ return f"rgba({ r } , { g } , { b } , { opacity :.2f} )"
308+
309+ # Apply color mapping to each value in the dataset
310+ df ["color" ] = df ["value" ].apply (apply_rgba )
311+
312+ # Create chart for the property layer
313+ chart = (
314+ alt .Chart (df )
315+ .mark_rect ()
316+ .encode (
317+ x = alt .X ("x:O" , axis = None ),
318+ y = alt .Y ("y:O" , axis = None ),
319+ fill = alt .Fill ("color:N" , scale = None ),
320+ )
321+ .properties (width = chart_width , height = chart_height , title = layer_name )
322+ )
323+ base = alt .layer (chart , base ) if base is not None else chart
324+
325+ # Add colorbar if specified in portrayal
326+ if colorbar :
327+ # Extract RGB components from base color
328+ rgb_color = to_rgb (portrayal ["color" ])
329+ r_int = int (rgb_color [0 ] * 255 )
330+ g_int = int (rgb_color [1 ] * 255 )
331+ b_int = int (rgb_color [2 ] * 255 )
332+
333+ # Define gradient endpoints
334+ min_color = f"rgba({ r_int } ,{ g_int } ,{ b_int } ,0)"
335+ max_color = f"rgba({ r_int } ,{ g_int } ,{ b_int } ,{ alpha :.2f} )"
336+
337+ # Define colorbar dimensions
338+ colorbar_height = 20
339+ colorbar_width = chart_width
340+
341+ # Create dataframe for gradient visualization
342+ df_gradient = pd .DataFrame ({"x" : [0 , 1 ], "y" : [0 , 1 ]})
343+
344+ # Create evenly distributed tick values
345+ axis_values = np .linspace (vmin , vmax , 11 )
346+ tick_positions = np .linspace (0 , colorbar_width , 11 )
347+
348+ # Prepare data for axis and labels
349+ axis_data = pd .DataFrame ({"value" : axis_values , "x" : tick_positions })
350+
351+ # Create colorbar with linear gradient
352+ colorbar_chart = (
353+ alt .Chart (df_gradient )
354+ .mark_rect (
355+ x = 0 ,
356+ y = 0 ,
357+ width = colorbar_width ,
358+ height = colorbar_height ,
359+ color = alt .Gradient (
360+ gradient = "linear" ,
361+ stops = [
362+ alt .GradientStop (color = min_color , offset = 0 ),
363+ alt .GradientStop (color = max_color , offset = 1 ),
364+ ],
365+ x1 = 0 ,
366+ x2 = 1 , # Horizontal gradient
367+ y1 = 0 ,
368+ y2 = 0 , # Keep y constant
369+ ),
370+ )
371+ .encode (
372+ x = alt .value (chart_width / 2 ), # Center colorbar
373+ y = alt .value (0 ),
374+ )
375+ .properties (width = colorbar_width , height = colorbar_height )
376+ )
377+
378+ # Add tick marks to colorbar
379+ axis_chart = (
380+ alt .Chart (axis_data )
381+ .mark_tick (thickness = 2 , size = 8 )
382+ .encode (x = alt .X ("x:Q" , axis = None ), y = alt .value (colorbar_height - 2 ))
383+ )
384+
385+ # Add value labels below tick marks
386+ text_labels = (
387+ alt .Chart (axis_data )
388+ .mark_text (baseline = "top" , fontSize = 10 , dy = 0 )
389+ .encode (
390+ x = alt .X ("x:Q" ),
391+ text = alt .Text ("value:Q" , format = ".1f" ),
392+ y = alt .value (colorbar_height + 10 ),
393+ )
394+ )
395+
396+ # Add title to colorbar
397+ title = (
398+ alt .Chart (pd .DataFrame ([{"text" : layer_name }]))
399+ .mark_text (
400+ fontSize = 12 ,
401+ fontWeight = "bold" ,
402+ baseline = "bottom" ,
403+ align = "center" ,
404+ )
405+ .encode (
406+ text = "text:N" ,
407+ x = alt .value (colorbar_width / 2 ),
408+ y = alt .value (colorbar_height + 40 ),
409+ )
410+ )
411+
412+ # Combine all colorbar components
413+ combined_colorbar = alt .layer (
414+ colorbar_chart , axis_chart , text_labels , title
415+ ).properties (width = colorbar_width , height = colorbar_height + 50 )
416+
417+ bar_chart = (
418+ alt .vconcat (bar_chart , combined_colorbar )
419+ .resolve_scale (color = "independent" )
420+ .configure_view (stroke = None )
421+ if bar_chart is not None
422+ else combined_colorbar
423+ )
424+
425+ elif "colormap" in portrayal :
426+ cmap = portrayal .get ("colormap" , "viridis" )
427+ cmap_scale = alt .Scale (scheme = cmap , domain = [vmin , vmax ])
428+
429+ chart = (
430+ alt .Chart (df )
431+ .mark_rect (opacity = alpha )
432+ .encode (
433+ x = alt .X ("x:O" , axis = None ),
434+ y = alt .Y ("y:O" , axis = None ),
435+ color = alt .Color (
436+ "value:Q" ,
437+ scale = cmap_scale ,
438+ title = layer_name ,
439+ legend = alt .Legend (title = layer_name ) if colorbar else None ,
440+ ),
441+ )
442+ .properties (width = chart_width , height = chart_height )
443+ )
444+ base = alt .layer (chart , base ) if base is not None else chart
445+
446+ else :
447+ raise ValueError (
448+ f"PropertyLayer { layer_name } portrayal must include 'color' or 'colormap'."
449+ )
450+ return base , bar_chart
0 commit comments