@@ -57,13 +57,17 @@ def __init__(self):
         self.process = psutil.Process()
         self.nvml_initialized = False
         self.gpu_handle = None
+        self.baseline_gpu_memory_mb = 0

         if NVML_AVAILABLE and nvml:
             try:
                 nvml.nvmlInit()
                 self.nvml_initialized = True
                 # Get the first GPU
                 self.gpu_handle = nvml.nvmlDeviceGetHandleByIndex(0)
+                # Store baseline GPU memory usage
+                mem_info = nvml.nvmlDeviceGetMemoryInfo(self.gpu_handle)
+                self.baseline_gpu_memory_mb = mem_info.used / 1024 / 1024
             except Exception as e:
                 print(f"Failed to initialize NVML: {e}")
                 self.nvml_initialized = False
@@ -113,34 +117,55 @@ def get_memory_usage(self) -> dict[str, float]:
         except Exception as e:
             print(f"Error getting RAM usage: {e}")

-        # Get total GPU memory usage (not per-process)
+        # Get GPU memory usage - use total system VRAM since extensions run in separate processes
         if self.nvml_initialized and self.gpu_handle:
             try:
                 # Get total GPU memory info
                 mem_info = nvml.nvmlDeviceGetMemoryInfo(self.gpu_handle)
-                memory_info["gpu_used_mb"] = mem_info.used / 1024 / 1024
+                current_used_mb = mem_info.used / 1024 / 1024
+                memory_info["gpu_used_mb"] = current_used_mb
                 memory_info["gpu_total_mb"] = mem_info.total / 1024 / 1024
-                memory_info["total_vram_mb"] = memory_info["gpu_used_mb"]
+                memory_info["total_vram_mb"] = current_used_mb

-                # Try to get per-process info (might return 0)
-                try:
-                    processes = nvml.nvmlDeviceGetComputeRunningProcesses(self.gpu_handle)
-                    our_pids = set(self.get_process_tree_pids())
-
-                    for proc in processes:
-                        if proc.pid in our_pids and proc.usedGpuMemory is not None:
-                            vram_mb = proc.usedGpuMemory / 1024 / 1024
-                            if proc.pid == self.process.pid:
-                                memory_info["host_vram_mb"] = vram_mb
-                except Exception as e:
-                    # Per-process tracking failed, use total GPU memory instead
-                    print(f"Warning: Per-process GPU tracking failed: {e}", file=sys.stderr)
+                # Calculate VRAM usage relative to baseline (captures all processes)
+                # This is more reliable than per-process tracking, especially on Windows
+                vram_delta = current_used_mb - self.baseline_gpu_memory_mb
+                memory_info["host_vram_mb"] = max(0, vram_delta)

             except Exception as e:
                 print(f"Error getting GPU memory usage: {e}")

+        # Fallback: try PyTorch CUDA memory for current process if NVML failed
+        elif CUDA_AVAILABLE and torch.cuda.is_available():
+            try:
+                # This only captures current process, but better than nothing
+                allocated_mb = torch.cuda.memory_allocated() / 1024 / 1024
+                reserved_mb = torch.cuda.memory_reserved() / 1024 / 1024
+
+                memory_info["host_vram_mb"] = allocated_mb
+                memory_info["total_vram_mb"] = allocated_mb
+                memory_info["pytorch_reserved_mb"] = reserved_mb
+
+                print(
+                    "Warning: Using PyTorch CUDA memory (current process only): "
+                    + f"{allocated_mb:.1f} MB allocated",
+                    file=sys.stderr,
+                )
+
+            except Exception as e:
+                print(f"Error getting PyTorch CUDA memory: {e}")
+
         return memory_info

+    def reset_baseline(self):
+        """Reset the baseline GPU memory measurement."""
+        if self.nvml_initialized and self.gpu_handle:
+            try:
+                mem_info = nvml.nvmlDeviceGetMemoryInfo(self.gpu_handle)
+                self.baseline_gpu_memory_mb = mem_info.used / 1024 / 1024
+            except Exception as e:
+                print(f"Error resetting GPU memory baseline: {e}")
+
     def __del__(self):
         """Cleanup NVML on deletion."""
         if self.nvml_initialized:
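
The core idea of the change is to snapshot system-wide VRAM at startup and report later readings as a delta over that baseline, rather than relying on NVML's per-process compute accounting. A minimal, self-contained sketch of that measurement pattern, assuming the `pynvml` package and GPU index 0 (not the repository's actual wrapper or class), is:

    # Sketch of the baseline-delta VRAM measurement. Assumes pynvml is installed
    # and that GPU index 0 is the device of interest.
    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    # Baseline: VRAM already in use before the workload starts.
    baseline_mb = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024

    # ... launch the GPU workload here (possibly in child processes) ...

    # Later readings are reported relative to the baseline, clamped at zero so
    # that other processes freeing memory cannot produce a negative figure.
    current_mb = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
    print(f"Workload VRAM: {max(0, current_mb - baseline_mb):.1f} MB over baseline")

    pynvml.nvmlShutdown()

Because the delta is taken over total device memory, it captures child processes that the per-process compute list can miss, at the cost of also counting allocations made by unrelated applications after the baseline was captured; `reset_baseline()` exists to re-anchor the measurement when that matters.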
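The PyTorch fallback records both `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` because PyTorch's caching allocator holds reserved blocks that are not currently backing live tensors. A short illustrative sketch of the distinction (assuming a CUDA-capable PyTorch build; not part of the diff):

    # Illustrative only: why the fallback reports allocated and reserved separately.
    # Assumes a CUDA-capable PyTorch install; exact numbers vary by device/driver.
    import torch

    if torch.cuda.is_available():
        x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB of float32

        allocated_mb = torch.cuda.memory_allocated() / 1024 / 1024  # memory backing live tensors
        reserved_mb = torch.cuda.memory_reserved() / 1024 / 1024    # total held by the caching allocator
        print(f"allocated={allocated_mb:.1f} MB, reserved={reserved_mb:.1f} MB")

        del x
        torch.cuda.empty_cache()  # returns cached (reserved but unallocated) blocks to the driver

Neither figure includes CUDA context overhead or allocations made outside PyTorch, which is why this path is only a per-process fallback when NVML is unavailable.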