Skip to content

Commit f1b9763

Browse files
authored
Add PropertyLayer visualisation with Altair (#2643)
1 parent e29fc80 commit f1b9763

File tree

2 files changed

+303
-22
lines changed

2 files changed

+303
-22
lines changed

mesa/visualization/components/altair_components.py

Lines changed: 275 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
import warnings
44

55
import altair as alt
6+
import numpy as np
7+
import pandas as pd
68
import solara
9+
from matplotlib.colors import to_rgb
710

11+
import mesa
812
from mesa.discrete_space import DiscreteSpace, Grid
9-
from mesa.space import ContinuousSpace, _Grid
13+
from mesa.space import ContinuousSpace, PropertyLayer, _Grid
1014
from mesa.visualization.utils import update_counter
1115

1216

@@ -20,13 +24,16 @@ def make_space_altair(*args, **kwargs): # noqa: D103
2024

2125

2226
def make_altair_space(
23-
agent_portrayal, propertylayer_portrayal, post_process, **space_drawing_kwargs
27+
agent_portrayal,
28+
propertylayer_portrayal=None,
29+
post_process=None,
30+
**space_drawing_kwargs,
2431
):
2532
"""Create an Altair-based space visualization component.
2633
2734
Args:
2835
agent_portrayal: Function to portray agents.
29-
propertylayer_portrayal: not yet implemented
36+
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
3037
post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks)
3138
space_drawing_kwargs : not yet implemented
3239
@@ -43,14 +50,23 @@ def agent_portrayal(a):
4350
return {"id": a.unique_id}
4451

4552
def MakeSpaceAltair(model):
46-
return SpaceAltair(model, agent_portrayal, post_process=post_process)
53+
return SpaceAltair(
54+
model,
55+
agent_portrayal,
56+
propertylayer_portrayal=propertylayer_portrayal,
57+
post_process=post_process,
58+
)
4759

4860
return MakeSpaceAltair
4961

5062

5163
@solara.component
5264
def SpaceAltair(
53-
model, agent_portrayal, dependencies: list[any] | None = None, post_process=None
65+
model,
66+
agent_portrayal,
67+
propertylayer_portrayal=None,
68+
dependencies: list[any] | None = None,
69+
post_process=None,
5470
):
5571
"""Create an Altair-based space visualization component.
5672
@@ -63,10 +79,11 @@ def SpaceAltair(
6379
# Sometimes the space is defined as model.space instead of model.grid
6480
space = model.space
6581

66-
chart = _draw_grid(space, agent_portrayal)
82+
chart = _draw_grid(space, agent_portrayal, propertylayer_portrayal)
6783
# Apply post-processing if provided
6884
if post_process is not None:
6985
chart = post_process(chart)
86+
7087
solara.FigureAltair(chart)
7188

7289

@@ -138,7 +155,7 @@ def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal):
138155
return all_agent_data
139156

140157

141-
def _draw_grid(space, agent_portrayal):
158+
def _draw_grid(space, agent_portrayal, propertylayer_portrayal):
142159
match space:
143160
case Grid():
144161
all_agent_data = _get_agent_data_new_discrete_space(space, agent_portrayal)
@@ -168,23 +185,266 @@ def _draw_grid(space, agent_portrayal):
168185
}
169186
has_color = "color" in all_agent_data[0]
170187
if has_color:
171-
encoding_dict["color"] = alt.Color("color", type="nominal")
188+
unique_colors = list({agent["color"] for agent in all_agent_data})
189+
encoding_dict["color"] = alt.Color(
190+
"color:N",
191+
scale=alt.Scale(domain=unique_colors, range=unique_colors),
192+
)
172193
has_size = "size" in all_agent_data[0]
173194
if has_size:
174195
encoding_dict["size"] = alt.Size("size", type="quantitative")
175196

176-
chart = (
197+
agent_chart = (
177198
alt.Chart(
178199
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
179200
)
180201
.mark_point(filled=True)
181-
.properties(width=280, height=280)
182-
# .configure_view(strokeOpacity=0) # hide grid/chart lines
202+
.properties(width=300, height=300)
183203
)
184-
# This is the default value for the marker size, which auto-scales
185-
# according to the grid area.
204+
base_chart = None
205+
cbar_chart = None
206+
207+
# This is the default value for the marker size, which auto-scales according to the grid area.
186208
if not has_size:
187209
length = min(space.width, space.height)
188-
chart = chart.mark_point(size=30000 / length**2, filled=True)
210+
agent_chart = agent_chart.mark_point(size=30000 / length**2, filled=True)
211+
212+
if propertylayer_portrayal is not None:
213+
chart_width = agent_chart.properties().width
214+
chart_height = agent_chart.properties().height
215+
base_chart, cbar_chart = chart_property_layers(
216+
space=space,
217+
propertylayer_portrayal=propertylayer_portrayal,
218+
chart_width=chart_width,
219+
chart_height=chart_height,
220+
)
221+
222+
base_chart = alt.layer(base_chart, agent_chart)
223+
else:
224+
base_chart = agent_chart
225+
if cbar_chart is not None:
226+
base_chart = alt.vconcat(base_chart, cbar_chart).configure_view(stroke=None)
227+
return base_chart
228+
229+
230+
def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_height):
231+
"""Creates Property Layers in the Altair Components.
232+
233+
Args:
234+
space: the ContinuousSpace instance
235+
propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
236+
chart_width: width of the agent chart to maintain consistency with the property charts
237+
chart_height: height of the agent chart to maintain consistency with the property charts
238+
agent_chart: the agent chart to layer with the property layers on the grid
239+
Returns:
240+
Altair Chart
241+
"""
242+
try:
243+
# old style spaces
244+
property_layers = space.properties
245+
except AttributeError:
246+
# new style spaces
247+
property_layers = space._mesa_property_layers
248+
base = None
249+
bar_chart = None
250+
for layer_name, portrayal in propertylayer_portrayal.items():
251+
layer = property_layers.get(layer_name, None)
252+
if not isinstance(
253+
layer,
254+
PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer,
255+
):
256+
continue
189257

190-
return chart
258+
data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
259+
260+
if (space.width, space.height) != data.shape:
261+
warnings.warn(
262+
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
263+
UserWarning,
264+
stacklevel=2,
265+
)
266+
alpha = portrayal.get("alpha", 1)
267+
vmin = portrayal.get("vmin", np.min(data))
268+
vmax = portrayal.get("vmax", np.max(data))
269+
colorbar = portrayal.get("colorbar", True)
270+
271+
# Prepare data for Altair (convert 2D array to a long-form DataFrame)
272+
df = pd.DataFrame(
273+
{
274+
"x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
275+
"y": np.tile(np.arange(data.shape[1]), data.shape[0]),
276+
"value": data.flatten(),
277+
}
278+
)
279+
280+
if "color" in portrayal:
281+
# Create a function to map values to RGBA colors with proper opacity scaling
282+
def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
283+
"""Maps data values to RGBA colors with opacity based on value magnitude.
284+
285+
Args:
286+
val: The data value to convert
287+
vmin: The smallest value for which the color is displayed in the colorbar
288+
vmax: The largest value for which the color is displayed in the colorbar
289+
alpha: The opacity of the color
290+
portrayal: The specifics of the current property layer in the iterative loop
291+
292+
Returns:
293+
String representation of RGBA color
294+
"""
295+
# Normalize value to range [0,1] and clamp
296+
normalized = max(0, min((val - vmin) / (vmax - vmin), 1))
297+
298+
# Scale opacity by alpha parameter
299+
opacity = normalized * alpha
300+
301+
# Convert color to RGB components
302+
rgb_color = to_rgb(portrayal["color"])
303+
r = int(rgb_color[0] * 255)
304+
g = int(rgb_color[1] * 255)
305+
b = int(rgb_color[2] * 255)
306+
307+
return f"rgba({r}, {g}, {b}, {opacity:.2f})"
308+
309+
# Apply color mapping to each value in the dataset
310+
df["color"] = df["value"].apply(apply_rgba)
311+
312+
# Create chart for the property layer
313+
chart = (
314+
alt.Chart(df)
315+
.mark_rect()
316+
.encode(
317+
x=alt.X("x:O", axis=None),
318+
y=alt.Y("y:O", axis=None),
319+
fill=alt.Fill("color:N", scale=None),
320+
)
321+
.properties(width=chart_width, height=chart_height, title=layer_name)
322+
)
323+
base = alt.layer(chart, base) if base is not None else chart
324+
325+
# Add colorbar if specified in portrayal
326+
if colorbar:
327+
# Extract RGB components from base color
328+
rgb_color = to_rgb(portrayal["color"])
329+
r_int = int(rgb_color[0] * 255)
330+
g_int = int(rgb_color[1] * 255)
331+
b_int = int(rgb_color[2] * 255)
332+
333+
# Define gradient endpoints
334+
min_color = f"rgba({r_int},{g_int},{b_int},0)"
335+
max_color = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})"
336+
337+
# Define colorbar dimensions
338+
colorbar_height = 20
339+
colorbar_width = chart_width
340+
341+
# Create dataframe for gradient visualization
342+
df_gradient = pd.DataFrame({"x": [0, 1], "y": [0, 1]})
343+
344+
# Create evenly distributed tick values
345+
axis_values = np.linspace(vmin, vmax, 11)
346+
tick_positions = np.linspace(0, colorbar_width, 11)
347+
348+
# Prepare data for axis and labels
349+
axis_data = pd.DataFrame({"value": axis_values, "x": tick_positions})
350+
351+
# Create colorbar with linear gradient
352+
colorbar_chart = (
353+
alt.Chart(df_gradient)
354+
.mark_rect(
355+
x=0,
356+
y=0,
357+
width=colorbar_width,
358+
height=colorbar_height,
359+
color=alt.Gradient(
360+
gradient="linear",
361+
stops=[
362+
alt.GradientStop(color=min_color, offset=0),
363+
alt.GradientStop(color=max_color, offset=1),
364+
],
365+
x1=0,
366+
x2=1, # Horizontal gradient
367+
y1=0,
368+
y2=0, # Keep y constant
369+
),
370+
)
371+
.encode(
372+
x=alt.value(chart_width / 2), # Center colorbar
373+
y=alt.value(0),
374+
)
375+
.properties(width=colorbar_width, height=colorbar_height)
376+
)
377+
378+
# Add tick marks to colorbar
379+
axis_chart = (
380+
alt.Chart(axis_data)
381+
.mark_tick(thickness=2, size=8)
382+
.encode(x=alt.X("x:Q", axis=None), y=alt.value(colorbar_height - 2))
383+
)
384+
385+
# Add value labels below tick marks
386+
text_labels = (
387+
alt.Chart(axis_data)
388+
.mark_text(baseline="top", fontSize=10, dy=0)
389+
.encode(
390+
x=alt.X("x:Q"),
391+
text=alt.Text("value:Q", format=".1f"),
392+
y=alt.value(colorbar_height + 10),
393+
)
394+
)
395+
396+
# Add title to colorbar
397+
title = (
398+
alt.Chart(pd.DataFrame([{"text": layer_name}]))
399+
.mark_text(
400+
fontSize=12,
401+
fontWeight="bold",
402+
baseline="bottom",
403+
align="center",
404+
)
405+
.encode(
406+
text="text:N",
407+
x=alt.value(colorbar_width / 2),
408+
y=alt.value(colorbar_height + 40),
409+
)
410+
)
411+
412+
# Combine all colorbar components
413+
combined_colorbar = alt.layer(
414+
colorbar_chart, axis_chart, text_labels, title
415+
).properties(width=colorbar_width, height=colorbar_height + 50)
416+
417+
bar_chart = (
418+
alt.vconcat(bar_chart, combined_colorbar)
419+
.resolve_scale(color="independent")
420+
.configure_view(stroke=None)
421+
if bar_chart is not None
422+
else combined_colorbar
423+
)
424+
425+
elif "colormap" in portrayal:
426+
cmap = portrayal.get("colormap", "viridis")
427+
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])
428+
429+
chart = (
430+
alt.Chart(df)
431+
.mark_rect(opacity=alpha)
432+
.encode(
433+
x=alt.X("x:O", axis=None),
434+
y=alt.Y("y:O", axis=None),
435+
color=alt.Color(
436+
"value:Q",
437+
scale=cmap_scale,
438+
title=layer_name,
439+
legend=alt.Legend(title=layer_name) if colorbar else None,
440+
),
441+
)
442+
.properties(width=chart_width, height=chart_height)
443+
)
444+
base = alt.layer(chart, base) if base is not None else chart
445+
446+
else:
447+
raise ValueError(
448+
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
449+
)
450+
return base, bar_chart

0 commit comments

Comments
 (0)