Skip to content

Commit bb0aa6e

Browse files
author
xesrc
committed
support for conversation export and import
1 parent d3181b6 commit bb0aa6e

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

chat/views.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,70 @@ def gen_title(request):
179179
})
180180

181181

182+
@api_view(['POST'])
183+
# @authentication_classes([JWTAuthentication])
184+
@permission_classes([IsAuthenticated])
185+
def upload_conversations(request):
186+
"""allow user to import a list of conversations"""
187+
user=request.user
188+
import_err_msg = 'bad_import'
189+
conversation_ids = []
190+
try:
191+
imports = request.data.get('imports')
192+
# verify
193+
conversations = []
194+
for conversation in imports:
195+
topic = conversation.get('conversation_topic', None)
196+
messages = []
197+
for message in conversation.get('messages'):
198+
msg = {}
199+
msg['role'] = message['role']
200+
msg['content'] = message['content']
201+
messages.append(msg)
202+
if len(messages) > 0:
203+
conversations.append({
204+
'topic': topic,
205+
'messages': messages,
206+
})
207+
# dump
208+
for conversation in conversations:
209+
topic = conversation['topic']
210+
messages = conversation['messages']
211+
cobj = Conversation(
212+
topic=topic if topic else '',
213+
user=user,
214+
)
215+
cobj.save()
216+
conversation_ids.append(cobj.id)
217+
for idx, msg in enumerate(messages):
218+
try:
219+
Message._meta.get_field('user')
220+
mobj = Message(
221+
user=user,
222+
conversation=cobj,
223+
message=msg['content'],
224+
is_bot=msg['role'] != 'user',
225+
messages=messages[:idx + 1],
226+
)
227+
except:
228+
mobj = Message(
229+
conversation=cobj,
230+
message=msg['content'],
231+
is_bot=msg['role'] != 'user',
232+
messages=messages[:idx + 1],
233+
)
234+
mobj.save()
235+
except Exception as e:
236+
logger.debug(e)
237+
return Response(
238+
{'error': import_err_msg},
239+
status=status.HTTP_400_BAD_REQUEST
240+
)
241+
242+
# return a list of new conversation id
243+
return Response(conversation_ids)
244+
245+
182246
@api_view(['POST'])
183247
# @authentication_classes([JWTAuthentication])
184248
@permission_classes([IsAuthenticated])

chatgpt_ui_server/urls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
"""
1616
from django.contrib import admin
1717
from django.urls import path, include
18-
from chat.views import conversation, gen_title
18+
from chat.views import conversation, gen_title, upload_conversations
1919

2020
urlpatterns = [
2121
path('api/chat/', include('chat.urls')),
2222
path('api/conversation/', conversation, name='conversation'),
23+
path('api/upload_conversations/', upload_conversations, name='upload_conversations'),
2324
path('api/gen_title/', gen_title, name='gen_title'),
2425
path('api/account/', include('account.urls')),
2526
path('admin/', admin.site.urls),

0 commit comments

Comments
 (0)