# internet_ml/internet_ml/NLP/no_context/QA.py
from typing import Any, List, Tuple
2022-12-26 15:43:10 +00:00
2022-12-27 06:38:47 +00:00
import logging
2022-12-30 05:28:26 +00:00
import os
2022-12-25 17:15:24 +00:00
import sys
from pathlib import Path
2022-12-30 05:28:26 +00:00
import dotenv
2022-12-25 17:15:24 +00:00
from transformers import pipeline
2022-12-30 05:28:26 +00:00
# Load variables from a .env file (if python-dotenv finds one) so that the
# os.environ lookups performed in answer() can be satisfied from it.
dotenv.load_dotenv()
2022-12-27 06:38:47 +00:00
# Configure the root logger: INFO and above goes to QA.log, and the file is
# opened in "w" mode, so each run starts from an empty log.
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s - %(levelname)s - %(message)s",
    filename="QA.log",
    filemode="w",
)
2022-12-25 17:15:24 +00:00
# Extend the import search path with two project directories so the plain
# `import config` / `import internet` statements that follow can resolve
# (presumably from these directories -- confirm against the repo layout).
# The root is three levels above this file; paths are joined with pathlib
# rather than by concatenating "/"-separated strings.
_PACKAGE_ROOT = Path(__file__).parent.parent.parent
sys.path.append(str(_PACKAGE_ROOT / "tools" / "NLP" / "data"))
sys.path.append(str(_PACKAGE_ROOT / "utils"))
import config
2022-12-25 17:15:24 +00:00
import internet
2022-12-26 15:43:10 +00:00
2022-12-27 12:19:01 +00:00
# Lazily-created question-answering pipeline shared by every answer() call.
_QA_MODEL: Any = None


def _get_qa_model() -> Any:
    """Return the question-answering pipeline, creating it on first use.

    Constructing a ``transformers`` pipeline loads a model, so the instance
    is stored at module level and reused instead of being rebuilt per query.
    """
    global _QA_MODEL
    if _QA_MODEL is None:
        _QA_MODEL = pipeline("question-answering")
    return _QA_MODEL


def answer(query: str) -> tuple[Any, list[str]]:
    """Answer ``query`` using Google search results as the model's context.

    Args:
        query: The question; it is used both as the search query and as the
            ``question`` argument of the QA pipeline.

    Returns:
        A 2-tuple of (QA pipeline output, second element of the search
        result). The second element is presumably the list of source URLs --
        confirm against ``internet.Google.google``.

    Raises:
        KeyError: If ``INTERNET_ML_GOOGLE_API`` or
            ``INTERNET_ML_GOOGLE_SEARCH_ENGINE_ID`` is missing from the
            environment.
    """
    # Read the credentials first so a missing variable fails before the
    # (potentially slow) model load and network search.
    api_key = str(os.environ["INTERNET_ML_GOOGLE_API"])
    engine_id = str(os.environ["INTERNET_ML_GOOGLE_SEARCH_ENGINE_ID"])

    qa_model = _get_qa_model()
    results: tuple[list[str], list[str]] = internet.Google(
        query, api_key, engine_id
    ).google()

    # NOTE(review): str() on a list passes its repr (brackets, quotes) as the
    # context. Kept as-is to preserve current model input; joining the items
    # may be what was intended -- confirm.
    result: tuple[Any, list[str]] = (
        qa_model(question=query, context=str(results[0])),
        results[1],
    )
    if config.CONF_DEBUG:
        logging.info("Answer: %s", result)
    return result
2022-12-27 13:40:36 +00:00
# print(answer("Who is the author of TinTin?"))
2022-12-27 12:19:01 +00:00
2022-12-27 06:38:47 +00:00
# def custom_answer