Commit 50ca97e

Speed up lora compute and lower memory usage by doing it in fp16. (#11161)

1 parent: 7ac7d69

File tree: 2 files changed, +17 −2 lines


comfy/model_management.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -1492,6 +1492,20 @@ def extended_fp16_support():

     return True

+LORA_COMPUTE_DTYPES = {}
+def lora_compute_dtype(device):
+    dtype = LORA_COMPUTE_DTYPES.get(device, None)
+    if dtype is not None:
+        return dtype
+
+    if should_use_fp16(device):
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
+    LORA_COMPUTE_DTYPES[device] = dtype
+    return dtype
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
```
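In plain terms, the new helper memoizes one compute dtype per device, so the `should_use_fp16` check runs at most once per device. Below is a minimal standalone sketch of the same caching pattern; `should_use_fp16` is stubbed out for illustration, since the real check lives in `comfy.model_management` and inspects the device's actual fp16 support:

```python
import torch

LORA_COMPUTE_DTYPES = {}

def should_use_fp16(device):
    # Stub for illustration only: the real comfy.model_management check
    # inspects hardware fp16 support. Here, assume CUDA devices prefer fp16.
    return device.type == "cuda"

def lora_compute_dtype(device):
    # Return the cached choice if we've already decided for this device.
    dtype = LORA_COMPUTE_DTYPES.get(device, None)
    if dtype is not None:
        return dtype

    # First call for this device: pick fp16 where supported, else fp32.
    dtype = torch.float16 if should_use_fp16(device) else torch.float32

    LORA_COMPUTE_DTYPES[device] = dtype
    return dtype

print(lora_compute_dtype(torch.device("cpu")))  # torch.float32 under the stub
```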

comfy/model_patcher.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -614,10 +614,11 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

+        temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
         if device_to is not None:
-            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
         else:
-            temp_weight = weight.to(torch.float32, copy=True)
+            temp_weight = weight.to(temp_dtype, copy=True)
         if convert_func is not None:
             temp_weight = convert_func(temp_weight, inplace=True)
```
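The practical effect in the patcher is that the temporary weight copy made while applying LoRA patches is half the size on devices where fp16 is chosen. A rough back-of-the-envelope sketch of the memory difference (the 4096x4096 shape is hypothetical, chosen only to make the arithmetic visible):

```python
import torch

# Hypothetical weight shape, not taken from the patch itself.
weight = torch.zeros(4096, 4096)

fp32_copy = weight.to(torch.float32, copy=True)
fp16_copy = weight.to(torch.float16, copy=True)

mib = lambda t: t.numel() * t.element_size() // (1024 * 1024)
print(mib(fp32_copy), "MiB in fp32 vs", mib(fp16_copy), "MiB in fp16")  # 64 vs 32
```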
