Skip to content

Commit 23178fe

Browse files
committed
feat(prof): update for local_peak gpu memory metric (#2284)
Updating `openvm-prof` to handle the new metrics from openvm-org/stark-backend#184. I have tested it locally.
1 parent 36daff5 commit 23178fe

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

crates/prof/src/lib.rs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ impl MetricDb {
131131
pub fn generate_gpu_memory_chart(&self) -> Option<String> {
132132
// (timestamp, current_gb, local_peak_gb, reserved_gb)
133133
let mut data: Vec<(f64, f64, f64, f64)> = Vec::new();
134-
// module -> [(tracked_gb, context_label)]
134+
// module -> [(local_peak_gb, context_label)]
135135
let mut module_stats: HashMap<String, Vec<(f64, String)>> = HashMap::new();
136136

137137
for (label_keys, metrics_dict) in &self.dict_by_label_types {
@@ -143,17 +143,17 @@ impl MetricDb {
143143
for (label_values, metrics) in metrics_dict {
144144
let get = |name: &str| metrics.iter().find(|m| m.name == name).map(|m| m.value);
145145
let ts = get("gpu_mem.timestamp_ms");
146-
let tracked = get("gpu_mem.tracked_bytes");
146+
let current = get("gpu_mem.current_bytes");
147+
let local_peak = get("gpu_mem.local_peak_bytes");
147148
let reserved = get("gpu_mem.reserved_bytes");
148-
let device = get("gpu_mem.device_bytes");
149149

150-
if let (Some(ts), Some(tracked), Some(reserved), Some(device)) =
151-
(ts, tracked, reserved, device)
150+
if let (Some(ts), Some(current), Some(local_peak), Some(reserved)) =
151+
(ts, current, local_peak, reserved)
152152
{
153-
let tracked_gb = tracked / 1e9;
153+
let current_gb = current / 1e9;
154+
let local_peak_gb = local_peak / 1e9;
154155
let reserved_gb = reserved / 1e9;
155-
let device_gb = device / 1e9;
156-
data.push((ts, tracked_gb, reserved_gb, device_gb));
156+
data.push((ts, current_gb, local_peak_gb, reserved_gb));
157157

158158
let module_name = label_values.get(module_idx).cloned().unwrap_or_default();
159159
let context_label: String = label_keys
@@ -167,7 +167,7 @@ impl MetricDb {
167167
module_stats
168168
.entry(module_name)
169169
.or_default()
170-
.push((tracked_gb, context_label));
170+
.push((local_peak_gb, context_label));
171171
}
172172
}
173173
}
@@ -178,10 +178,10 @@ impl MetricDb {
178178

179179
data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
180180

181-
let max_tracked = data.iter().map(|(_, t, _, _)| *t).fold(0.0_f64, f64::max);
182-
let max_reserved = data.iter().map(|(_, _, r, _)| *r).fold(0.0_f64, f64::max);
183-
let max_device = data.iter().map(|(_, _, _, d)| *d).fold(0.0_f64, f64::max);
184-
let chart_max = max_tracked.max(max_reserved).max(max_device);
181+
let max_current = data.iter().map(|(_, c, _, _)| *c).fold(0.0_f64, f64::max);
182+
let max_local_peak = data.iter().map(|(_, _, lp, _)| *lp).fold(0.0_f64, f64::max);
183+
let max_reserved = data.iter().map(|(_, _, _, r)| *r).fold(0.0_f64, f64::max);
184+
let chart_max = max_current.max(max_local_peak).max(max_reserved);
185185

186186
let mut chart = String::new();
187187
chart.push_str("```mermaid\n");
@@ -200,42 +200,42 @@ impl MetricDb {
200200
" y-axis \"Memory (GB)\" 0 --> {:.1}\n",
201201
chart_max * 1.1
202202
));
203-
// Tracked memory line (blue)
203+
// Current memory line (blue)
204204
chart.push_str(" line [");
205205
chart.push_str(
206206
&data
207207
.iter()
208-
.map(|(_, tracked, _, _)| format!("{:.2}", tracked))
208+
.map(|(_, current, _, _)| format!("{:.2}", current))
209209
.collect::<Vec<_>>()
210210
.join(", "),
211211
);
212212
chart.push_str("]\n");
213-
// Reserved memory line (green)
213+
// Local peak memory line (green)
214214
chart.push_str(" line [");
215215
chart.push_str(
216216
&data
217217
.iter()
218-
.map(|(_, _, reserved, _)| format!("{:.2}", reserved))
218+
.map(|(_, _, local_peak, _)| format!("{:.2}", local_peak))
219219
.collect::<Vec<_>>()
220220
.join(", "),
221221
);
222222
chart.push_str("]\n");
223-
// Device memory line (red)
223+
// Reserved memory line (red)
224224
chart.push_str(" line [");
225225
chart.push_str(
226226
&data
227227
.iter()
228-
.map(|(_, _, _, device)| format!("{:.2}", device))
228+
.map(|(_, _, _, reserved)| format!("{:.2}", reserved))
229229
.collect::<Vec<_>>()
230230
.join(", "),
231231
);
232232
chart.push_str("]\n");
233233
chart.push_str("```\n");
234234

235235
chart.push_str("\n> ");
236-
chart.push_str("🔵 Tracked (Current) | ");
237-
chart.push_str("🟢 Reserved (Pool) | ");
238-
chart.push_str("🔴 Device\n");
236+
chart.push_str("🔵 Current | ");
237+
chart.push_str("🟢 Local Peak | ");
238+
chart.push_str("🔴 Reserved (Pool)\n");
239239

240240
// Per-module stats table
241241
chart.push_str("\n| Module | Max (GB) | Max At |\n");

0 commit comments

Comments
 (0)