From 89198e53c265b4d8f79263d2ecd3ca40608c6c83 Mon Sep 17 00:00:00 2001
From: better629
Date: Wed, 15 Mar 2023 19:33:13 +0800
Subject: [PATCH 1/2] add gradio support

---
 inference/gradio_openchatkit.py | 90 +++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)
 create mode 100644 inference/gradio_openchatkit.py

diff --git a/inference/gradio_openchatkit.py b/inference/gradio_openchatkit.py
new file mode 100644
index 0000000..3269c84
--- /dev/null
+++ b/inference/gradio_openchatkit.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : openchat kit gradio application
+"""
+    Run:
+        # under OpenChatKit/inference from https://github.com/togethercomputer/OpenChatKit
+        CUDA_VISIBLE_DEVICES=2,3 python3 gradio_openchatkit.py
+    Warn:
+        larger max_new_tokens values use more CUDA memory, so be careful
+"""
+
+import os
+import sys
+
+CUR_DIR = os.path.abspath(os.path.dirname(__file__))
+MODEL_PATH = os.path.join(CUR_DIR, "../../GPT-NeoXT-Chat-Base-20B/")
+
+sys.path.append(CUR_DIR)
+
+from loguru import logger
+import gradio as gr
+import argparse
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from bot import ChatModel
+
+
+class ConvChat(object):
+    """
+    Conversation Chat
+    """
+    def __init__(self,
+                 model_name: str,
+                 max_new_tokens: int = 256,
+                 sample: bool = False,
+                 temperature: float = 0.6,
+                 top_k: int = 40):
+        self.max_new_tokens = max_new_tokens
+        self.sample = sample
+        self.temperature = temperature
+        self.top_k = top_k
+
+        logger.info("Start to init Chat Model")
+        self.chat_model = ChatModel(model_name=model_name, gpu_id=0)
+        logger.info("Initialized Chat Model")
+
+    def run_text(self, input_text: gr.Textbox, state: gr.State):
+
+        response = self.chat_model.do_inference(
+            prompt=input_text,
+            max_new_tokens=self.max_new_tokens,
+            do_sample=self.sample,
+            temperature=self.temperature,
+            top_k=self.top_k
+        )
+
+        state = state + [(input_text, response)]
+        return state, state
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name", type=str, default=MODEL_PATH, help="model huggingface repo name or local path")
+    parser.add_argument("--server_port", type=int, default=7800, help="gradio server port")
+
+    args = parser.parse_args()
+
+    conv_chat = ConvChat(model_name=args.model_name)
+
+    with gr.Blocks(css="OpenChatKit .overflow-y-auto{height:500px}") as gr_chat:
+        chatbot = gr.Chatbot(elem_id="chatbot", label="OpenChatKit")
+        state = gr.State([])
+
+        with gr.Row():
+            with gr.Column(scale=0.8):
+                input_text = gr.Textbox(show_label=False,
+                                        placeholder="Enter your question").style(container=False)
+            with gr.Column(scale=0.2, min_width=0):
+                clear_btn = gr.Button("Clear")
+
+        input_text.submit(conv_chat.run_text, [input_text, state], [chatbot, state])
+        input_text.submit(lambda: "", None, input_text)
+
+        clear_btn.click(lambda: [], None, chatbot)
+        clear_btn.click(lambda: [], None, state)
+
+    gr_chat.launch(
+        server_name="0.0.0.0",
+        server_port=args.server_port
+    )

From d9ea1b3ef1a7133561d5586ddda5a1e4e37a4566 Mon Sep 17 00:00:00 2001
From: better629
Date: Fri, 31 Mar 2023 11:03:37 +0800
Subject: [PATCH 2/2] update chat prompt template and model path

---
 inference/gradio_openchatkit.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/inference/gradio_openchatkit.py b/inference/gradio_openchatkit.py
index 3269c84..15381f8 100644
--- a/inference/gradio_openchatkit.py
+++ b/inference/gradio_openchatkit.py
@@ -13,7 +13,7 @@
 import sys
 
 CUR_DIR = os.path.abspath(os.path.dirname(__file__))
-MODEL_PATH = os.path.join(CUR_DIR, "../../GPT-NeoXT-Chat-Base-20B/")
+MODEL_PATH = os.path.join(CUR_DIR, "../huggingface_models/GPT-NeoXT-Chat-Base-20B/")
 
 sys.path.append(CUR_DIR)
 
@@ -23,6 +23,7 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from bot import ChatModel
+from conversation import Conversation
 
 
 class ConvChat(object):
@@ -42,17 +43,22 @@ def __init__(self,
 
         logger.info("Start to init Chat Model")
         self.chat_model = ChatModel(model_name=model_name, gpu_id=0)
+
+        self.conv = Conversation(self.chat_model.human_id, self.chat_model.bot_id)
         logger.info("Initialized Chat Model")
 
     def run_text(self, input_text: gr.Textbox, state: gr.State):
+        self.conv.push_human_turn(input_text)
 
-        response = self.chat_model.do_inference(
-            prompt=input_text,
+        output = self.chat_model.do_inference(
+            prompt=self.conv.get_raw_prompt(),
             max_new_tokens=self.max_new_tokens,
             do_sample=self.sample,
             temperature=self.temperature,
             top_k=self.top_k
        )
+        self.conv.push_model_response(output)
+        response = self.conv.get_last_turn()
 
         state = state + [(input_text, response)]
         return state, state
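
Note for reviewers: below is a minimal sketch of the multi-turn prompt flow that
PATCH 2/2 introduces. It uses only the Conversation calls the diff itself makes
(push_human_turn, get_raw_prompt, push_model_response, get_last_turn); the
EchoModel stand-in and the "<human>"/"<bot>" turn ids are assumptions for
illustration, not OpenChatKit's real ChatModel, which exposes its own
human_id/bot_id and runs the 20B model.

    # Sketch of the PATCH 2/2 prompt flow. Run from OpenChatKit/inference so that
    # `from conversation import Conversation` resolves; EchoModel is a hypothetical
    # stand-in that lets the call order be exercised without loading the 20B model.
    from conversation import Conversation


    class EchoModel:
        # Assumed turn ids; the real ChatModel provides its own human_id/bot_id.
        human_id = "<human>"
        bot_id = "<bot>"

        def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
            # A real model continues the raw prompt; here we just append a reply.
            return prompt + " Hi there!"


    model = EchoModel()
    conv = Conversation(model.human_id, model.bot_id)

    conv.push_human_turn("What is OpenChatKit?")   # record the user turn
    output = model.do_inference(
        prompt=conv.get_raw_prompt(),              # prompt now carries the full history
        max_new_tokens=256,
        do_sample=False,
        temperature=0.6,
        top_k=40,
    )
    conv.push_model_response(output)               # record the raw model output
    print(conv.get_last_turn())                    # only the newest bot reply

With this wrapper in place the Gradio callback only ever displays
get_last_turn(), so the raw prompt template (the <human>/<bot> turns) should
never leak into the chat window, and each request sees the whole conversation
rather than just the latest textbox input.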