internet_ml/internet_ml/NLP/no_context/QA.py

104 lines
3.4 KiB
Python
Raw Normal View History

2023-01-11 15:59:46 +00:00
# type: ignore
"""
model naming convention
# Open-AI models:
include prefix openai-*
# HuggingFace
include prefix hf-*
"""
from typing import Any, List, Tuple

import os
import sys
from pathlib import Path

import dotenv
import openai
from transformers import pipeline

# Make the project's tool/util packages importable as top-level modules.
_PROJECT_ROOT = str(Path(__file__).parent.parent.parent)
sys.path.append(_PROJECT_ROOT + "/tools/NLP/data")
sys.path.append(_PROJECT_ROOT + "/tools/NLP")
sys.path.append(_PROJECT_ROOT + "/tools")
sys.path.append(_PROJECT_ROOT + "/utils")

import config
import internet
from ChatGPT import Chatbot

# Pull API keys / session tokens from a local .env file into os.environ.
dotenv.load_dotenv()
2022-12-26 15:43:10 +00:00
2022-12-30 06:50:36 +00:00
def answer(
    query: str,
    model: str = "openai-chatgpt",
    GOOGLE_SEARCH_API_KEY: str = "",
    GOOGLE_SEARCH_ENGINE_ID: str = "",
    OPENAI_API_KEY: str = "",
    CHATGPT_SESSION_TOKEN: str = "",
    CHATGPT_CONVERSATION_ID: str = "",
    CHATGPT_PARENT_ID: str = "",
) -> tuple[Any, list[str]]:
    """Answer *query* using Google search results as grounding context.

    The model is selected by prefix: ``openai-*`` routes to OpenAI
    (``openai-chatgpt`` via the ChatGPT session client, or
    ``openai-text-davinci-003`` via the Completion API); ``hf-*`` routes to a
    HuggingFace question-answering pipeline named by the rest of the string.
    Any other value falls back to ``openai-chatgpt``.

    Credentials left as ``""`` are filled from the environment (loaded from
    ``.env`` at import time).

    Returns:
        ``(answer_text, source_urls)`` — the model's answer plus the list of
        URLs returned by the internet search.
    """
    # Fill credentials from the environment when not passed explicitly.
    if OPENAI_API_KEY == "":
        OPENAI_API_KEY = str(os.environ.get("OPENAI_API_KEY"))
    openai.api_key = OPENAI_API_KEY
    if CHATGPT_SESSION_TOKEN == "":
        CHATGPT_SESSION_TOKEN = str(os.environ.get("CHATGPT_SESSION_TOKEN"))
    # NOTE(review): the conversation/parent ids are loaded but never passed to
    # the Chatbot below (the current code pins both to None) — confirm whether
    # per-conversation continuity is intended before wiring them through.
    if CHATGPT_CONVERSATION_ID == "":
        CHATGPT_CONVERSATION_ID = str(os.environ.get("CHATGPT_CONVERSATION_ID"))
    if CHATGPT_PARENT_ID == "":
        CHATGPT_PARENT_ID = str(os.environ.get("CHATGPT_PARENT_ID"))

    # Unknown model names fall back to the default rather than erroring.
    if not (model.startswith("openai-") or model.startswith("hf-")):
        model = "openai-chatgpt"  # Default

    results: tuple[list[str], list[str]] = internet.Google(
        query, GOOGLE_SEARCH_API_KEY, GOOGLE_SEARCH_ENGINE_ID
    ).google()

    if model.startswith("openai-"):
        # Build the grounded prompt once for both OpenAI paths.
        # BUG FIX: `context` was previously referenced in the davinci branch
        # without ever being defined (NameError); it is now computed here.
        # Keep only string snippets and cap at 3000 chars to bound prompt size.
        context = " ".join(filter(lambda x: isinstance(x, str), results[0]))[:3000]
        prompt = f"Using the context: {context} and answer the question with the context above and previous knowledge: \"{query}\". Also write long answers or essays if asked."
        if model == "openai-chatgpt":
            # ChatGPT
            # BUG FIX: removed leftover debug `print(prompt); exit(1)` that
            # killed the whole process before the model was ever called.
            chatbot = Chatbot(
                {"session_token": CHATGPT_SESSION_TOKEN},
                conversation_id=None,
                parent_id=None,
            )
            response = chatbot.ask(
                prompt=prompt,
                conversation_id=None,
                parent_id=None,
            )
            return (response["message"], results[1])
        if model == "openai-text-davinci-003":
            # text-davinci-003 completion endpoint.
            response = openai.Completion.create(
                model="text-davinci-003",
                prompt=prompt,
                # NOTE(review): character count is a rough stand-in for a
                # token budget — confirm this is the intended cap.
                max_tokens=len(context),
                n=1,
                stop=None,
                temperature=0.5,
            )
            return (response.choices[0].text, results[1])
        # TODO: add support for other openai-* models later
    else:
        # HuggingFace: strip the "hf-" prefix and load that QA pipeline.
        model = model.replace("hf-", "", 1)
        qa_model = pipeline("question-answering", model=model)
        response = qa_model(question=query, context=" ".join(results[0]))
        return (response["answer"], results[1])
2022-12-27 06:38:47 +00:00
2023-01-12 10:41:31 +00:00
# Demo entry point: guard the network-touching call so importing this module
# (e.g. `from QA import answer`) no longer fires a live query as a side effect.
if __name__ == "__main__":
    print(
        answer(
            query="Best original song in 80th Golden Globe award 2023?",
            model="openai-chatgpt",
        )
    )