# type: ignore
"""Question answering over internet search results.

Model naming convention:
  * "openai-*" -> OpenAI models (e.g. "openai-chatgpt", "openai-text-davinci-003")
  * "hf-*"     -> HuggingFace models (the "hf-" prefix is stripped before loading)
"""

import os
import sys
from pathlib import Path
from typing import Any, List, Tuple

# Make the project-local tool/util packages importable when this file is run
# directly. Three levels up from this file is the repository root.
_REPO_ROOT = Path(__file__).parent.parent.parent
sys.path.append(str(_REPO_ROOT) + "/tools/NLP/data")
sys.path.append(str(_REPO_ROOT) + "/tools/NLP")
sys.path.append(str(_REPO_ROOT) + "/tools")
sys.path.append(str(_REPO_ROOT) + "/utils")

# These imports rely on the sys.path manipulation above, so they must stay
# below it (hence the E402 suppressions).
import config  # noqa: E402
import dotenv  # noqa: E402
import internet  # noqa: E402
import openai  # noqa: E402
from ChatGPT import Chatbot  # noqa: E402
from transformers import pipeline  # noqa: E402

# Pull API keys / session tokens from a .env file into os.environ so the
# defaults in answer() can pick them up.
dotenv.load_dotenv()
def answer(
    query: str,
    model: str = "openai-chatgpt",
    GOOGLE_SEARCH_API_KEY: str = "",
    GOOGLE_SEARCH_ENGINE_ID: str = "",
    OPENAI_API_KEY: str = "",
    CHATGPT_SESSION_TOKEN: str = "",
) -> tuple[Any, list[str]]:
    """Answer ``query`` using Google search results as context.

    Args:
        query: Natural-language question to answer.
        model: Backend selector. ``"openai-chatgpt"`` uses the ChatGPT web
            client, ``"openai-text-davinci-003"`` uses the OpenAI completion
            API, and any ``"hf-<model-id>"`` uses a HuggingFace
            question-answering pipeline (the ``hf-`` prefix is stripped).
            Anything else falls back to ``"openai-chatgpt"``.
        GOOGLE_SEARCH_API_KEY: Key for the Google Custom Search helper.
        GOOGLE_SEARCH_ENGINE_ID: Engine id for the Google Custom Search helper.
        OPENAI_API_KEY: OpenAI API key; read from the environment when empty.
        CHATGPT_SESSION_TOKEN: ChatGPT session token; read from the
            environment when empty.

    Returns:
        A tuple ``(answer_text, source_urls)``.

    Raises:
        ValueError: If an unsupported ``"openai-*"`` model is requested.
    """
    # Fall back to the environment for credentials. Default to "" so that a
    # missing variable does not become the literal string "None"
    # (str(os.environ.get(...)) did exactly that in the original).
    if OPENAI_API_KEY == "":
        OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
    # Apply the key whether it came from the argument or the environment.
    openai.api_key = OPENAI_API_KEY
    if CHATGPT_SESSION_TOKEN == "":
        CHATGPT_SESSION_TOKEN = os.environ.get("CHATGPT_SESSION_TOKEN", "")

    if not (model.startswith("openai-") or model.startswith("hf-")):
        model = "openai-chatgpt"  # Default

    # results[0]: page texts/snippets used as context; results[1]: source URLs.
    results: tuple[list[str], list[str]] = internet.Google(
        query, GOOGLE_SEARCH_API_KEY, GOOGLE_SEARCH_ENGINE_ID
    ).google()
    context: str = "".join(str(string) for string in results[0])
    print(f"context: {context}")

    if model.startswith("openai-"):
        if model == "openai-chatgpt":
            # ChatGPT web client, authenticated with the session token.
            prompt = f'Use the context: {context[:4000]} and answer the question: "{query}" with the context and prior knowledge. Also write at the very least long answers.'
            chatbot = Chatbot(
                {"session_token": CHATGPT_SESSION_TOKEN},
                conversation_id=None,
                parent_id=None,
            )
            response = chatbot.ask(
                prompt=prompt,
                conversation_id=None,
                parent_id=None,
            )
            return (response["message"], results[1])
        if model == "openai-text-davinci-003":
            # OpenAI completion API; smaller context budget than ChatGPT.
            prompt = f'Use the context: {context[:3000]} and answer the question: "{query}" with the context and prior knowledge. Also write at the very least long answers.'
            response = openai.Completion.create(
                model="text-davinci-003",
                prompt=prompt,
                # NOTE(review): len(context) counts characters, not tokens, and
                # can exceed the model's completion limit — confirm the
                # intended token budget.
                max_tokens=len(context),
                n=1,
                stop=None,
                temperature=0.5,
            )
            return (response.choices[0].text, results[1])
        # TODO: add support for other OpenAI models later. Fail loudly instead
        # of silently returning None (the original fell off the end here).
        raise ValueError(f"unsupported OpenAI model: {model}")

    # HuggingFace extractive QA pipeline; strip the "hf-" routing prefix to
    # recover the hub model id (e.g. "hf-deepset/..." -> "deepset/...").
    model = model.replace("hf-", "", 1)
    qa_model = pipeline("question-answering", model=model)
    response = qa_model(question=query, context=context)
    return (response["answer"], results[1])
# Demo: only run when executed as a script, so importing this module does not
# trigger a network call and model download.
if __name__ == "__main__":
    print(
        answer(
            query="What is the newest pokemon game?",
            model="hf-deepset/xlm-roberta-large-squad2",
        )
    )