Skip to content

Commit b8dded6

Browse files
Merge pull request #7 from charlesjlee/main
Refactor and add support for inpainting
2 parents 97b1fd6 + c46d50d commit b8dded6

File tree

2 files changed

+165
-112
lines changed

2 files changed

+165
-112
lines changed

README.md

Lines changed: 101 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,37 @@
11
# Get Access
2-
32
[labs.openai.com/waitlist](https://labs.openai.com/waitlist)
43

5-
- Go to https://labs.openai.com/
6-
- Open Network Tab in Developer Tools
7-
- Type a prompt and press "Generate"
8-
- Look for fetch to https://labs.openai.com/api/labs/tasks
9-
- In the request header look for authorization then get the Bearer Token
10-
11-
12-
# Usage
4+
# Installation
135
```bash
146
pip install dalle2
157
```
8+
9+
# Usage
10+
## Setup
11+
1. Go to https://labs.openai.com/
12+
1. Open Network Tab in Developer Tools
13+
1. Type a prompt and press "Generate"
14+
1. Look for fetch to https://labs.openai.com/api/labs/tasks
15+
1. In the request header look for authorization then get the Bearer Token
16+
1617
```python
1718
from dalle2 import Dalle2
18-
1919
dalle = Dalle2("sess-xxxxxxxxxxxxxxxxxxxxxxxxxxxx")
20-
generations = dalle.generate("portal to another dimension, digital art")
20+
```
2121

22+
## Generate images
23+
```python
24+
generations = dalle.generate("portal to another dimension, digital art")
2225
print(generations)
2326
```
2427

2528
```
26-
✔️ Task created with ID: task-f77yxcsdf3OEm and PROMPT: portal to another dimension, digital art
29+
✔️ Task created with ID: task-xsuhOthvBXLEjddn3ynyiiOR
2730
⌛ Waiting for task to finish...
31+
...task not completed yet
32+
...task not completed yet
33+
...task not completed yet
34+
...task not completed yet
2835
🙌 Task completed!
2936
3037
[
@@ -40,69 +47,110 @@ print(generations)
4047
'prompt_id': 'prompt-2CtaLQsgUbJHHDoJQy9Lul3T',
4148
'is_public': false
4249
},
43-
{
44-
'id': 'generation-hZWt2Nasrx8R0tJjbaROfKVy',
45-
'object': 'generation',
46-
'created': 1553332711,
47-
'generation_type': 'ImageGeneration',
48-
'generation': {
49-
'image_path': 'https://openailabsprodscus.blob.core.windows.net/private/user-hadpVzldsfs28CwvEZYMUT/generations/generation...'
50-
},
51-
'task_id': 'task-nERkiKhjasdSZ50yD69qewID',
52-
'prompt_id': 'prompt-2CtaLasdUbJHHfoJQy9Lul3T',
53-
'is_public': false
54-
},
55-
# 2 more ...
50+
# 3 more ...
5651
]
5752
```
5853

59-
or download all generations
60-
54+
## Download images
6155
```python
62-
from dalle2 import Dalle2
56+
file_paths = dalle.download(generations)
57+
print(file_paths)
58+
```
6359

64-
dalle = Dalle2("sess-xxxxxxxxxxxxxxxxxxxxxxxxxxxx")
65-
generations = dalle.generate_and_download("portal to another dimension, digital art")
60+
```
61+
✔️ Downloaded: C:\...\generation-XySidj4N8EN6Ok9ed15BZ2bs.png
62+
✔️ Downloaded: C:\...\generation-IK3UdxDz77FA5SLKpQPIITdU.png
63+
✔️ Downloaded: C:\...\generation-uNejKBXz1z6EQxJAT9pAZbof.png
64+
✔️ Downloaded: C:\...\generation-Ol1wEqNprf34vNohmJz0iUiE.png
65+
66+
[
67+
'C:/.../generation-pvi9TEWrhciLyFIlfgF1XUHF.png',
68+
'C:/.../generation-xp545V8jsqhSKKyJydHZPL50.png',
69+
'C:/.../generation-wNODqnBhvzYvXasonBn1anIA.png',
70+
'C:/.../generation-InPSaWWxpapT8TJD0kI71hNM.png'
71+
]
72+
```
6673

74+
## Generate images and download them
75+
```python
76+
file_paths = dalle.generate_and_download("portal to another dimension, digital art")
6777
```
6878

6979
```
70-
✔️ Task created with ID: task-f77sayxcSGdfOEm and PROMPT: portal to another dimension, digital art
80+
✔️ Task created with ID: task-xsuhOthvBXLEjddn3ynyiiOR
7181
⌛ Waiting for task to finish...
82+
...task not completed yet
83+
...task not completed yet
84+
...task not completed yet
85+
...task not completed yet
7286
🙌 Task completed!
73-
Download to directory: C:\Users\pc\dalle2
74-
✔️ Downloaded: generation-fAq4Lyxcm7pQVDBQEWJ.jpg
75-
✔️ Downloaded: generation-zqfBC3yyxcPXRlW6zLP.jpg
76-
✔️ Downloaded: generation-soR3ryxcoeixzdyHG.jpg
77-
✔️ Downloaded: generation-lT5L4yxc2DOiGRwJi.jpg
87+
✔️ Downloaded: C:\...\generation-XySidj4N8EN6Ok9ed15BZ2bs.png
88+
✔️ Downloaded: C:\...\generation-IK3UdxDz77FA5SLKpQPIITdU.png
89+
✔️ Downloaded: C:\...\generation-uNejKBXz1z6EQxJAT9pAZbof.png
90+
✔️ Downloaded: C:\...\generation-Ol1wEqNprf34vNohmJz0iUiE.png
7891
```
7992

80-
81-
or generate a specific amount
82-
93+
## Generate a specific number of images
8394
```python
84-
from dalle2 import Dalle2
85-
86-
dalle = Dalle2("sess-xxxxxxxxxxxxxxxxxxxxxxxxxxxx")
87-
generations = dalle.generate_amount("portal to another dimension", 12) # Every generation has batch size 4 -> amount % 4 == 0 works best
88-
89-
print(generations)
95+
generations = dalle.generate_amount("portal to another dimension", 8) # Every generation has batch size 4 -> amount % 4 == 0 works best
9096
```
9197

9298
```
93-
✔️ Task created with ID: task-lm0V4nZasgAFasd7AsStE67 and PROMPT: portal to another dimension OVERALL: 1/ 2
99+
✔️ Task created with ID: task-lm0V4nZasgAFasd7AsStE67
100+
⌛ Waiting for task to finish...
101+
...task not completed yet
102+
...task not completed yet
103+
...task not completed yet
104+
...task not completed yet
105+
🙌 Task completed!
106+
✔️ Task created with ID: task-WcetZOHt8asdvHb433gi
94107
⌛ Waiting for task to finish...
95-
➕ Appended new generations to all_generations
96-
✔️ Task created with ID: task-WcetZOHt8asdvHb433gi and PROMPT: portal to another dimension OVERALL: 2/ 2
97-
⌛ Waiting for task to finish ..
98-
➕ Appended new generations to all_generations
108+
...task not completed yet
109+
...task not completed yet
110+
...task not completed yet
111+
...task not completed yet
99112
🙌 Task completed!
113+
```
114+
115+
## Generate images from a masked file
116+
DALL·E supports an "inpainting" API that fills-in transparent parts of an image.
117+
The website provides a tool to paint over an existing image to indicate which
118+
parts you want to be transparent. This Python package call assumes that the
119+
image you provide has already been processed to have transparent parts.
100120

121+
```python
122+
# make the right half of a saved image transparent
123+
from PIL import Image, ImageDraw
124+
125+
image = Image.open('my_image.png')
126+
m, n = image.size
127+
128+
area_to_keep = (0, 0, m//2, n)
129+
image_alpha = Image.new("L", image.size, 0)
130+
draw = ImageDraw.Draw(image_alpha)
131+
draw.rectangle(area_to_keep, fill=255)
132+
133+
image_rgba = image.copy()
134+
image_rgba.putalpha(image_alpha)
135+
image_rgba.save('image_with_transparent_right_half.png')
136+
137+
# ask DALL·E to fill-in the transparent right half
138+
generations = dalle.generate_from_masked_image(
139+
"portal to another dimension, digital art",
140+
"image_with_transparent_right_half.png",
141+
)
101142
```
143+
102144
```
103-
-> [list]
145+
✔️ Task created with ID: task-xsuhOthvBXLEjddn3ynyiiOR
146+
⌛ Waiting for task to finish...
147+
...task not completed yet
148+
...task not completed yet
149+
...task not completed yet
150+
...task not completed yet
151+
...task not completed yet
152+
🙌 Task completed!
104153
```
105154

155+
# Try it!
106156
[![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EEgZNAI58V_OiEfRJQSsQV_xkhHzQeRB?usp=sharing)
107-
108-

src/dalle2/dalle2.py

Lines changed: 64 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1+
import base64
12
import json
3+
import math
24
import os
35
import requests
46
import time
57
import urllib
68
import urllib.request
79

10+
from pathlib import Path
11+
812
class Dalle2():
913
def __init__(self, bearer):
1014
self.bearer = bearer
1115
self.batch_size = 4
16+
self.inpainting_batch_size = 3
17+
self.task_sleep_seconds = 3
1218

1319
def generate(self, prompt):
14-
url = "https://labs.openai.com/api/labs/tasks"
15-
headers = {
16-
'Authorization': "Bearer " + self.bearer,
17-
'Content-Type': "application/json",
18-
}
1920
body = {
2021
"task_type": "text2im",
2122
"prompt": {
@@ -24,76 +25,80 @@ def generate(self, prompt):
2425
}
2526
}
2627

28+
return self.get_task_response(body)
29+
30+
def generate_and_download(self, prompt, image_dir=os.getcwd()):
31+
generations = self.generate(prompt)
32+
if not generations:
33+
return None
34+
35+
return self.download(generations, image_dir)
36+
37+
def generate_amount(self, prompt, amount):
38+
if amount < self.batch_size:
39+
raise ValueError(f"passed amount of {amount} cannot be smaller than the batch size of {self.batch_size}")
40+
41+
return [self.generate(prompt) for _ in range(math.ceil(amount / self.batch_size))]
42+
43+
def generate_from_masked_image(self, prompt, image_path):
44+
with open(image_path, "rb") as f:
45+
image_base64 = base64.b64encode(f.read())
46+
47+
body = {
48+
"task_type": "inpainting",
49+
"prompt": {
50+
"caption": prompt,
51+
"batch_size": self.inpainting_batch_size,
52+
"image": image_base64.decode(),
53+
"masked_image": image_base64.decode(), # identical since already masked
54+
}
55+
}
56+
57+
return self.get_task_response(body)
58+
59+
def get_task_response(self, body):
60+
url = "https://labs.openai.com/api/labs/tasks"
61+
headers = {
62+
'Authorization': "Bearer " + self.bearer,
63+
'Content-Type': "application/json",
64+
}
65+
2766
response = requests.post(url, headers=headers, data=json.dumps(body))
2867
if response.status_code != 200:
2968
print(response.text)
3069
return None
3170
data = response.json()
32-
print("✔️ Task created with ID:", data["id"], "and PROMPT:", prompt)
71+
print(f"✔️ Task created with ID: {data['id']}")
3372
print("⌛ Waiting for task to finish...")
3473

3574
while True:
36-
url = "https://labs.openai.com/api/labs/tasks/" + data["id"]
75+
url = f"https://labs.openai.com/api/labs/tasks/{data['id']}"
3776
response = requests.get(url, headers=headers)
3877
data = response.json()
78+
79+
if not response.ok:
80+
print(f"Request failed with status: {response.status_code}, data: {response.json()}")
81+
return None
82+
if data["status"] == "failed":
83+
print(f"Task failed: {data['status_information']}")
84+
return None
3985
if data["status"] == "succeeded":
4086
print("🙌 Task completed!")
41-
generations = data["generations"]["data"]
42-
return generations
87+
return data["generations"]["data"]
4388

44-
time.sleep(3)
89+
print("...task not completed yet")
90+
time.sleep(self.task_sleep_seconds)
4591

46-
def generate_and_download(self, prompt):
47-
generations = self.generate(prompt)
92+
def download(self, generations, image_dir=os.getcwd()):
4893
if not generations:
49-
return None
94+
raise ValueError("generations is empty!")
5095

51-
print("Download to directory: " + os.getcwd())
96+
file_paths = []
5297
for generation in generations:
5398
image_url = generation["generation"]["image_path"]
54-
image_id = generation["id"]
55-
56-
urllib.request.urlretrieve(image_url, image_id +".jpg")
57-
print("✔️ Downloaded: ", image_id + ".jpg")
58-
59-
return generations
60-
61-
def generate_amount(self, prompt, amount):
62-
url = "https://labs.openai.com/api/labs/tasks"
63-
headers = {
64-
'Authorization': "Bearer " + self.bearer,
65-
'Content-Type': "application/json",
66-
}
67-
body = {
68-
"task_type": "text2im",
69-
"prompt": {
70-
"caption": prompt,
71-
"batch_size": self.batch_size,
72-
}
73-
}
99+
file_path = Path(image_dir, generation['id']).with_suffix('.png')
100+
file_paths.append(str(file_path))
101+
urllib.request.urlretrieve(image_url, file_path)
102+
print(f"✔️ Downloaded: {file_path}")
74103

75-
all_generations = []
76-
for i in range(1, int(amount / self.batch_size +1)):
77-
url = "https://labs.openai.com/api/labs/tasks"
78-
response = requests.post(url, headers=headers, data=json.dumps(body))
79-
if response.status_code != 200:
80-
print(response.text)
81-
return None
82-
data = response.json()
83-
print("✔️ Task created with ID:", data["id"], "and PROMPT:", prompt, "OVERALL:", str(i) + "/", int(amount / self.batch_size))
84-
print("⌛ Waiting for task to finish...")
85-
86-
while True:
87-
url = "https://labs.openai.com/api/labs/tasks/" + data["id"]
88-
response = requests.get(url, headers=headers)
89-
data = response.json()
90-
if data["status"] == "succeeded":
91-
generations = data["generations"]["data"]
92-
print("➕ Appended new generations to all_generations")
93-
all_generations.append(generations)
94-
break
95-
96-
time.sleep(3)
97-
print("🙌 Task completed!")
98-
print(all_generations)
99-
return all_generations
104+
return file_paths

0 commit comments

Comments
 (0)