45 changes: 43 additions & 2 deletions comfy/model_management.py
@@ -39,6 +39,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    OCL = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
@@ -101,6 +102,14 @@ def get_supported_float8_types():
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
+ocl_available = False
+try:
+    import pytorch_ocl
+    import torch.ocl
+    ocl_available = True
+except ImportError:
+    pass
+
 try:
     import intel_extension_for_pytorch as ipex # noqa: F401
 except:
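
Review note: the try/except above only proves that pytorch_ocl can be imported, not that an OpenCL device is actually usable. A minimal smoke test could harden the probe; this is a sketch, assuming the backend accepts the same "ocl:0" device string used later in this patch (the helper name is hypothetical):

```python
import torch

def ocl_device_usable():
    # Hypothetical helper: try a tiny allocation plus one op on the first
    # OpenCL device; any backend/driver failure counts as "not usable".
    try:
        import pytorch_ocl  # noqa: F401  (registers the "ocl" device type)
        t = torch.ones(1, device="ocl:0")
        return bool((t + t).cpu().item() == 2.0)
    except Exception:
        return False
```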
@@ -138,6 +147,10 @@ def get_supported_float8_types():
 except:
     ixuca_available = False
 
+if ocl_available:
+    # TODO gate behind flag.
+    cpu_state = CPUState.OCL
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
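
The "TODO gate behind flag" above could be resolved by making OpenCL opt-in instead of auto-selected whenever the import succeeds. A sketch, assuming a hypothetical --use-ocl switch were first added to comfy/cli_args.py (no such flag exists yet):

```python
# Hypothetical gating; requires adding `--use-ocl` to the argument parser.
if ocl_available and args.use_ocl:
    cpu_state = CPUState.OCL
```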
@@ -167,6 +180,12 @@ def is_ixuca():
         return True
     return False
 
+def is_ocl():
+    global ocl_available
+    if ocl_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -177,6 +196,8 @@ def get_torch_device():
         return torch.device("mps")
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
+    if cpu_state == CPUState.OCL:
+        return torch.device("ocl:0")
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
@@ -192,7 +213,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
@@ -217,6 +238,9 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_ocl():
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
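
One way to replace the hardcoded 1 GiB: query CL_DEVICE_GLOBAL_MEM_SIZE directly, assuming pyopencl is acceptable as an extra dependency (the TODO suggests torch.ocl does not expose a mem_get_info equivalent yet). A sketch; the mapping from torch's device index to the flattened OpenCL device list is an assumption:

```python
import pyopencl as cl

def ocl_total_memory(device_index=0):
    # Sketch: read CL_DEVICE_GLOBAL_MEM_SIZE (bytes) for the chosen device.
    # Assumes platform/device enumeration order matches pytorch_ocl's numbering.
    devices = [d for p in cl.get_platforms() for d in p.get_devices()]
    return devices[device_index].global_mem_size
```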
@@ -1231,7 +1255,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -1259,6 +1283,15 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_ocl():
+            # stats = torch.ocl.memory_stats(dev)
+            # mem_active = stats['active_bytes.all.current']
+            # mem_reserved = stats['reserved_bytes.all.current']
+            # mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
+            # mem_free_torch = mem_reserved - mem_active
+            # mem_free_total = mem_free_ocl + mem_free_torch
+            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
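
Free memory is harder: plain OpenCL has no portable "free VRAM" query, and the commented-out lines above mirror allocator stats that torch.ocl may not provide. Until such stats exist, a conservative fraction of the total would at least scale with the hardware instead of pinning everything to 1 GiB. Heuristic sketch, reusing the hypothetical ocl_total_memory() from the note above:

```python
def ocl_free_memory_estimate(device_index=0, safety_fraction=0.8):
    # Heuristic, not a measurement: assume a fixed share of device memory
    # is available for weights; safety_fraction is a made-up tunable.
    return int(ocl_total_memory(device_index) * safety_fraction)
```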
@@ -1337,6 +1370,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if is_mlu():
         return True
 
+    if is_ocl():
+        # TODO ? RustiCL now supports fp16 at least.
+        return True
+
     if is_ixuca():
         return True
 
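
Rather than returning True unconditionally, fp16 support could be probed once at startup (RustiCL, for instance, only offers fp16 when the driver exposes cl_khr_fp16). A hedged sketch, dtype-parameterized so the bf16 case further down can reuse it:

```python
import torch

def ocl_dtype_works(dtype, device="ocl:0"):
    # Sketch: run one tiny op in the given dtype on the OpenCL device; if the
    # result round-trips as finite values, treat the dtype as supported.
    try:
        a = torch.ones(2, 2, dtype=dtype, device=device)
        return bool(torch.isfinite((a @ a).float().cpu()).all())
    except Exception:
        return False
```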
@@ -1413,6 +1450,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
                 return True
             return False
 
+    if is_ocl():
+        # TODO
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
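
The bf16 TODO above could reuse the same probe, e.g.:

```python
if is_ocl():
    return ocl_dtype_works(torch.bfloat16)  # hypothetical probe from the fp16 note
```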