Skip to content

Commit f17251b

Browse files
authored
Account for the VRAM cost of weight offloading (#10733)
* mm: default to 0 for NUM_STREAMS Don't count the compute stream as an offload stream. This makes async offload accounting easier. * mm: remove 128MB minimum This is from a previous offloading system requirement. Remove it to make the behaviour of the loader and partial unloader consistent. * mp: order the module list by offload expense Calculate an approximate temporary VRAM cost to offload a weight and primarily order the module load list by that. In the simple case this is just the same as the module weight, but with LoRAs, a weight with a LoRA consumes considerably more VRAM to do the LoRA application on-the-fly. This will slightly prioritize LoRA weights, but is really for proper VRAM offload accounting. * mp: Account for the VRAM cost of weight offloading when checking the VRAM headroom: assume that the weight needs to be offloaded, and only load it if there is space for both the load and the offload times the number of streams. As the weights are ordered from largest to smallest by offload cost, this is guaranteed to fit in VRAM (tm), as all weights that follow will be smaller. Make the partial unload aware of this system as well by saving the budget for offload VRAM to the model state and accounting accordingly. It's possible that partial unload increases the size of the largest offloaded weights, and thus needs to unload a little bit more than asked to accommodate the bigger temp buffers. Honor the existing code's floor on model weight loading of 128MB by having the patcher honor this separately without regard to offloading. Otherwise, when MM specifies its 128MB minimum, MP will see the biggest weights and budget that 128MB to only the offload buffer, loading nothing — which isn't the intent of these minimums. The same clamp applies in the case of partial offload of the currently loading model.
1 parent c38e7d6 commit f17251b

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

comfy/model_management.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
689689
loaded_memory = loaded_model.model_loaded_memory()
690690
current_free_mem = get_free_memory(torch_dev) + loaded_memory
691691

692-
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
692+
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
693693
lowvram_model_memory = lowvram_model_memory - loaded_memory
694694

695695
if lowvram_model_memory == 0:
@@ -1012,7 +1012,7 @@ def force_channels_last():
10121012

10131013

10141014
STREAMS = {}
1015-
NUM_STREAMS = 1
1015+
NUM_STREAMS = 0
10161016
if args.async_offload:
10171017
NUM_STREAMS = 2
10181018
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
@@ -1030,7 +1030,7 @@ def current_stream(device):
10301030
stream_counters = {}
10311031
def get_offload_stream(device):
10321032
stream_counter = stream_counters.get(device, 0)
1033-
if NUM_STREAMS <= 1:
1033+
if NUM_STREAMS == 0:
10341034
return None
10351035

10361036
if device in STREAMS:

comfy/model_patcher.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,15 @@ def __call__(self, weight):
148148
else:
149149
return out
150150

151+
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
152+
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
153+
154+
def low_vram_patch_estimate_vram(model, key):
155+
weight, set_func, convert_func = get_key_weight(model, key)
156+
if weight is None:
157+
return 0
158+
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
159+
151160
def get_key_weight(model, key):
152161
set_func = None
153162
convert_func = None
@@ -269,6 +278,9 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
269278
if not hasattr(self.model, 'current_weight_patches_uuid'):
270279
self.model.current_weight_patches_uuid = None
271280

281+
if not hasattr(self.model, 'model_offload_buffer_memory'):
282+
self.model.model_offload_buffer_memory = 0
283+
272284
def model_size(self):
273285
if self.size > 0:
274286
return self.size
@@ -662,7 +674,16 @@ def _load_list(self):
662674
skip = True # skip random weights in non leaf modules
663675
break
664676
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
665-
loading.append((comfy.model_management.module_size(m), n, m, params))
677+
module_mem = comfy.model_management.module_size(m)
678+
module_offload_mem = module_mem
679+
if hasattr(m, "comfy_cast_weights"):
680+
weight_key = "{}.weight".format(n)
681+
bias_key = "{}.bias".format(n)
682+
if weight_key in self.patches:
683+
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
684+
if bias_key in self.patches:
685+
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
686+
loading.append((module_offload_mem, module_mem, n, m, params))
666687
return loading
667688

668689
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -676,20 +697,22 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
676697

677698
load_completely = []
678699
offloaded = []
700+
offload_buffer = 0
679701
loading.sort(reverse=True)
680702
for x in loading:
681-
n = x[1]
682-
m = x[2]
683-
params = x[3]
684-
module_mem = x[0]
703+
module_offload_mem, module_mem, n, m, params = x
685704

686705
lowvram_weight = False
687706

707+
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
708+
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
709+
688710
weight_key = "{}.weight".format(n)
689711
bias_key = "{}.bias".format(n)
690712

691713
if not full_load and hasattr(m, "comfy_cast_weights"):
692-
if mem_counter + module_mem >= lowvram_model_memory:
714+
if not lowvram_fits:
715+
offload_buffer = potential_offload
693716
lowvram_weight = True
694717
lowvram_counter += 1
695718
lowvram_mem_counter += module_mem
@@ -723,9 +746,11 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
723746
if hasattr(m, "comfy_cast_weights"):
724747
wipe_lowvram_weight(m)
725748

726-
if full_load or mem_counter + module_mem < lowvram_model_memory:
749+
if full_load or lowvram_fits:
727750
mem_counter += module_mem
728751
load_completely.append((module_mem, n, m, params))
752+
else:
753+
offload_buffer = potential_offload
729754

730755
if cast_weight and hasattr(m, "comfy_cast_weights"):
731756
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -766,7 +791,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
766791
self.pin_weight_to_device("{}.{}".format(n, param))
767792

768793
if lowvram_counter > 0:
769-
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
794+
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
770795
self.model.model_lowvram = True
771796
else:
772797
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@@ -778,6 +803,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
778803
self.model.lowvram_patch_counter += patch_counter
779804
self.model.device = device_to
780805
self.model.model_loaded_weight_memory = mem_counter
806+
self.model.model_offload_buffer_memory = offload_buffer
781807
self.model.current_weight_patches_uuid = self.patches_uuid
782808

783809
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@@ -831,6 +857,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
831857
self.model.to(device_to)
832858
self.model.device = device_to
833859
self.model.model_loaded_weight_memory = 0
860+
self.model.model_offload_buffer_memory = 0
834861

835862
for m in self.model.modules():
836863
if hasattr(m, "comfy_patched_weights"):
@@ -849,13 +876,14 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
849876
patch_counter = 0
850877
unload_list = self._load_list()
851878
unload_list.sort()
879+
offload_buffer = self.model.model_offload_buffer_memory
880+
852881
for unload in unload_list:
853-
if memory_to_free < memory_freed:
882+
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
854883
break
855-
module_mem = unload[0]
856-
n = unload[1]
857-
m = unload[2]
858-
params = unload[3]
884+
module_offload_mem, module_mem, n, m, params = unload
885+
886+
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
859887

860888
lowvram_possible = hasattr(m, "comfy_cast_weights")
861889
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -906,15 +934,18 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
906934
m.comfy_cast_weights = True
907935
m.comfy_patched_weights = False
908936
memory_freed += module_mem
937+
offload_buffer = max(offload_buffer, potential_offload)
909938
logging.debug("freed {}".format(n))
910939

911940
for param in params:
912941
self.pin_weight_to_device("{}.{}".format(n, param))
913942

943+
914944
self.model.model_lowvram = True
915945
self.model.lowvram_patch_counter += patch_counter
916946
self.model.model_loaded_weight_memory -= memory_freed
917-
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
947+
self.model.model_offload_buffer_memory = offload_buffer
948+
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
918949
return memory_freed
919950

920951
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

0 commit comments

Comments
 (0)