Skip to content

Commit d03df6a

Browse files
committed
add an assert on torch version in order to use OpenAIDiscreteVAE
1 parent 2612a51 commit d03df6a

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

dalle_pytorch/vae.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from pathlib import Path
1111
from tqdm import tqdm
1212
from math import sqrt, log
13+
from packaging import version
14+
1315
from omegaconf import OmegaConf
1416
from taming.models.vqgan import VQModel, GumbelVQ
1517
import importlib
@@ -98,11 +100,18 @@ def make_contiguous(module):
98100
for param in module.parameters():
99101
param.set_(param.contiguous())
100102

103+
# package versions
104+
105+
def get_pkg_version(pkg_name):
106+
from pkg_resources import get_distribution
107+
return get_distribution(pkg_name).version
108+
101109
# pretrained Discrete VAE from OpenAI
102110

103111
class OpenAIDiscreteVAE(nn.Module):
104112
def __init__(self):
105113
super().__init__()
114+
assert version.parse(get_pkg_version('torch')) < version.parse('1.11.0'), 'torch version must be <= 1.10 in order to use OpenAI discrete vae'
106115

107116
self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
108117
self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'dalle-pytorch',
55
packages = find_packages(),
66
include_package_data = True,
7-
version = '1.6.1',
7+
version = '1.6.2',
88
license='MIT',
99
description = 'DALL-E - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)