diff --git a/src/plots/all_gsps.py b/src/plots/all_gsps.py
index 4d2fd1d..81f889f 100644
--- a/src/plots/all_gsps.py
+++ b/src/plots/all_gsps.py
@@ -1,6 +1,6 @@
from plotly import graph_objects as go
-from plots.utils import line_color
+from plots.utils import get_colour_from_model_name, colour_per_model
def make_all_gsps_plots(x_mae_all_gsp, y_mae_all_gsp):
@@ -12,13 +12,17 @@ def make_all_gsps_plots(x_mae_all_gsp, y_mae_all_gsp):
legend=go.layout.Legend(title=go.layout.legend.Title(text="Chart Legend")),
)
)
+
+ model_name = "All_GSPs"
+ color = get_colour_from_model_name(model_name)
+
fig7.add_traces(
go.Scatter(
x=x_mae_all_gsp,
y=y_mae_all_gsp,
mode="lines",
name="Daily Latest MAE All GSPs",
- line=dict(color=line_color[4]),
+ line=dict(color=color),
),
)
return fig7
diff --git a/src/plots/forecast_horizon.py b/src/plots/forecast_horizon.py
index fcee7b3..c06d7cb 100644
--- a/src/plots/forecast_horizon.py
+++ b/src/plots/forecast_horizon.py
@@ -1,7 +1,7 @@
import pandas as pd
from plotly import graph_objects as go
-from plots.utils import line_color, get_x_y, MAE_LIMIT_DEFAULT
+from plots.utils import get_x_y, MAE_LIMIT_DEFAULT, get_colour_from_model_name
def make_mae_by_forecast_horizon(
@@ -19,6 +19,7 @@ def make_mae_by_forecast_horizon(
legend=go.layout.Legend(title=go.layout.legend.Title(text="Chart Legend")),
)
)
+
fig2.add_trace(
go.Scatter(
x=df_mae["datetime_utc"],
@@ -39,6 +40,8 @@ def make_mae_by_forecast_horizon(
}
)
+ color = get_colour_from_model_name(f"forecast_horizon_{i}")
+
fig2.add_traces(
[
go.Scatter(
@@ -46,7 +49,7 @@ def make_mae_by_forecast_horizon(
y=df["MAE"],
name=f"{forecast_horizon}-minute horizon",
mode="lines",
- line=dict(color=line_color[i%len(line_color)]),
+ line=dict(color=color),
)
]
)
@@ -80,6 +83,8 @@ def make_mae_forecast_horizon_group_by_forecast_horizon(
}
)
+ color = get_colour_from_model_name(f"forecast_horizon_{i}")
+
fig.add_traces(
[
go.Scatter(
@@ -87,7 +92,7 @@ def make_mae_forecast_horizon_group_by_forecast_horizon(
y=df_mae_horizon["forecast_horizon"],
name=f"{forecast_horizon}-minute horizon",
mode="markers",
- line=dict(color=line_color[i%len(line_color)]),
+ line=dict(color=color),
),
]
)
@@ -149,13 +154,16 @@ def make_mae_vs_forecast_horizon_group_by_date(
# sort results by day
results_for_day = result[i]
results_for_day = results_for_day.sort_values(by=["forecast_horizon"], ascending=True)
+
+ color = get_colour_from_model_name(f"forecast_date_{i}")
+
traces.append(
go.Scatter(
x=results_for_day["forecast_horizon"].sort_values(ascending=True),
y=results_for_day["MAE"],
name=results_for_day["datetime_utc"].iloc[0].strftime("%Y-%m-%d"),
mode="lines+markers",
- line=dict(color=line_color[i % len(line_color)]),
+ line=dict(color=color),
)
)
fig.add_traces(traces)
diff --git a/src/plots/mae_and_rmse.py b/src/plots/mae_and_rmse.py
index 9b9ba98..de4cd3f 100644
--- a/src/plots/mae_and_rmse.py
+++ b/src/plots/mae_and_rmse.py
@@ -1,6 +1,6 @@
from plotly import graph_objects as go, express as px
-from plots.utils import line_color, MAE_LIMIT_DEFAULT_HORIZON_0, MAE_LIMIT_DEFAULT
+from plots.utils import get_colour_from_model_name, MAE_LIMIT_DEFAULT_HORIZON_0, MAE_LIMIT_DEFAULT
def make_rmse_and_mae_plot(df_mae, df_rmse, x_plive_mae, x_plive_rmse, y_plive_mae, y_plive_rmse):
@@ -12,6 +12,10 @@ def make_rmse_and_mae_plot(df_mae, df_rmse, x_plive_mae, x_plive_rmse, y_plive_m
legend=go.layout.Legend(title=go.layout.legend.Title(text="Chart Legend")),
)
)
+
+ mae_color = get_colour_from_model_name("MAE")
+ rmse_color = get_colour_from_model_name("RMSE")
+
fig.add_traces(
[
go.Scatter(
@@ -19,28 +23,28 @@ def make_rmse_and_mae_plot(df_mae, df_rmse, x_plive_mae, x_plive_rmse, y_plive_m
y=df_mae["MAE"],
name="MAE",
mode="lines",
- line=dict(color=line_color[0]),
+ line=dict(color=mae_color),
),
go.Scatter(
x=df_rmse["datetime_utc"],
y=df_rmse["RMSE"],
name="RMSE",
mode="lines",
- line=dict(color=line_color[1]),
+ line=dict(color=rmse_color),
),
go.Scatter(
x=x_plive_mae,
y=y_plive_mae,
name="MAE PVLive",
mode="lines",
- line=dict(color=line_color[0], dash="dash"),
+ line=dict(color=mae_color, dash="dash"),
),
go.Scatter(
x=x_plive_rmse,
y=y_plive_rmse,
name="RMSE PVLive",
mode="lines",
- line=dict(color=line_color[1], dash="dash"),
+ line=dict(color=rmse_color, dash="dash"),
),
]
)
@@ -57,7 +61,7 @@ def make_mae_plot(df_mae):
"
Its actually the MAE for the last forecast made, which is normally the same as the "
"0 minute forecast horizon",
hover_data=["MAE", "datetime_utc"],
- color_discrete_sequence=["#FFAC5F"],
+ color_discrete_sequence=[get_colour_from_model_name("MAE_bar")],
)
fig.update_layout(yaxis_range=[0, MAE_LIMIT_DEFAULT_HORIZON_0])
return fig
diff --git a/src/plots/pinball_and_exceedance_plots.py b/src/plots/pinball_and_exceedance_plots.py
index 4ff2b49..a1f777b 100644
--- a/src/plots/pinball_and_exceedance_plots.py
+++ b/src/plots/pinball_and_exceedance_plots.py
@@ -9,7 +9,7 @@
from nowcasting_datamodel.models.metric import MetricValue
from get_data import get_metric_value
-from .utils import line_color
+from .utils import get_colour_from_model_name
def make_pinball_or_exceedance_plot(
@@ -61,6 +61,8 @@ def make_pinball_or_exceedance_plot(
x_horizon = [value.datetime_interval.start_datetime_utc for value in metric_values]
y_horizon = [round(float(value.value), 2) for value in metric_values]
+ color = get_colour_from_model_name(f"{metric_name}_p{plevel}_{forecast_horizon}")
+
# add to plot
fig.add_traces(
[
@@ -69,7 +71,7 @@ def make_pinball_or_exceedance_plot(
y=y_horizon,
name=f"p{plevel}_{forecast_horizon}-minute horizon",
mode="lines",
- line=dict(color=line_color[i%len(line_color)]),
+ line=dict(color=color),
)
]
)
diff --git a/src/plots/ramp_rate.py b/src/plots/ramp_rate.py
index 6b6b3b0..dd06695 100644
--- a/src/plots/ramp_rate.py
+++ b/src/plots/ramp_rate.py
@@ -6,7 +6,7 @@
from nowcasting_datamodel.models.metric import MetricValue
from get_data import get_metric_value
-from .utils import line_color
+from .utils import get_colour_from_model_name
def make_ramp_rate_plot(
@@ -48,6 +48,8 @@ def make_ramp_rate_plot(
x_horizon = [value.datetime_interval.start_datetime_utc for value in metric_values]
y_horizon = [round(float(value.value), 2) for value in metric_values]
+ color = get_colour_from_model_name(f"{metric_name}_{forecast_horizon}")
+
# add to plot
fig.add_traces(
[
@@ -56,7 +58,7 @@ def make_ramp_rate_plot(
y=y_horizon,
name=f"{forecast_horizon}-minute horizon",
mode="lines",
- line=dict(color=line_color[forecast_horizon_selection.index(forecast_horizon)]),
+ line=dict(color=color),
)
]
)
diff --git a/src/plots/utils.py b/src/plots/utils.py
index 9a477e5..0165344 100644
--- a/src/plots/utils.py
+++ b/src/plots/utils.py
@@ -2,17 +2,9 @@
from nowcasting_datamodel.read.read_models import get_models
import os
from datetime import datetime, timedelta
+import plotly.express as px
-line_color = [
- "#9EC8FA",
- "#9AA1F9",
- "#FFAC5F",
- "#9F973A",
- "#7BCDF3",
- "#086788",
- "#63BCAF",
- "#4C9A8E",
-]
+PALETTE = px.colors.qualitative.Dark24
colour_per_model = {
"cnn": "#FFD053",
@@ -26,10 +18,6 @@
"PVLive GSP Sum Updated": "#FF9736",
}
-# Make a cycle for extra models not in colour_per_model
-# Skip first 3 colours as they are too similar to colours in colour_per_model
-line_color_cycle = cycle(line_color[3:])
-
def hex_to_rgb(value):
value = value.lstrip("#")
lv = len(value)
@@ -39,14 +27,15 @@ def hex_to_rgb(value):
def get_colour_from_model_name(model_name, opacity=1.0):
"""Get colour from model label"""
if "PVLive" in model_name:
- colour = colour_per_model.get(model_name, "#FFFFFF")
+ return colour_per_model.get(model_name, "#FFFFFF")
else:
# Some models have a space and a datetime
model_name_only = model_name.split(" ")[0]
if model_name_only in colour_per_model:
colour = colour_per_model[model_name_only]
else:
- colour = next(line_color_cycle)
+ idx = abs(hash(model_name_only)) % len(PALETTE)
+ colour = PALETTE[idx]
colour_per_model[model_name_only] = colour
return colour