Skip to content

Commit 5ddeecf

Browse files
committed
stable reinmax
1 parent 66b573b commit 5ddeecf

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

dalle_pytorch/dalle_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def inner(model, *args, **kwargs):
5151
# sampling helpers
5252

5353
def log(t, eps = 1e-20):
54-
return torch.log(t + eps)
54+
return torch.log(t.clamp(min = eps))
5555

5656
def gumbel_noise(t):
5757
noise = torch.zeros_like(t).uniform_(0, 1)
@@ -239,7 +239,7 @@ def forward(
239239
one_hot = one_hot.detach()
240240
π0 = logits.softmax(dim = 1)
241241
π1 = (one_hot + (logits / temp).softmax(dim = 1)) / 2
242-
π1 = ((π1.log() - logits).detach() + logits).softmax(dim = 1)
242+
π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1)
243243
π2 = 2 * π1 - 0.5 * π0
244244
one_hot = π2 - π2.detach() + one_hot
245245

dalle_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.6.5'
1+
__version__ = '1.6.6'

0 commit comments

Comments
 (0)