good qa model
parent
fe61b63a0b
commit
a7b10f646a
|
@ -44,9 +44,11 @@ def answer(
|
||||||
if not (model.startswith("openai-") or model.startswith("hf-")):
|
if not (model.startswith("openai-") or model.startswith("hf-")):
|
||||||
model = "openai-chatgpt" # Default
|
model = "openai-chatgpt" # Default
|
||||||
|
|
||||||
|
print("Scraping the Internet")
|
||||||
results: tuple[list[str], list[str]] = internet.Google(
|
results: tuple[list[str], list[str]] = internet.Google(
|
||||||
query, GOOGLE_SEARCH_API_KEY, GOOGLE_SEARCH_ENGINE_ID
|
query, GOOGLE_SEARCH_API_KEY, GOOGLE_SEARCH_ENGINE_ID
|
||||||
).google()
|
).google()
|
||||||
|
print("Done scraping the Internet")
|
||||||
context: str = str(" ".join([str(string) for string in results]))
|
context: str = str(" ".join([str(string) for string in results]))
|
||||||
print(f"context: {context}")
|
print(f"context: {context}")
|
||||||
|
|
||||||
|
@ -85,11 +87,3 @@ def answer(
|
||||||
qa_model = pipeline("question-answering", model=model)
|
qa_model = pipeline("question-answering", model=model)
|
||||||
response = qa_model(question=query, context=context)
|
response = qa_model(question=query, context=context)
|
||||||
return (response["answer"], results[1])
|
return (response["answer"], results[1])
|
||||||
|
|
||||||
|
|
||||||
print(
|
|
||||||
answer(
|
|
||||||
query="What is the newest pokemon game?",
|
|
||||||
model="openai-chatgpt",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
from QA import answer
|
||||||
|
|
||||||
|
print(
|
||||||
|
answer(
|
||||||
|
query="When was the last cricket worldcup held?",
|
||||||
|
model="hf-deepset/roberta-large-squad2",
|
||||||
|
)
|
||||||
|
)
|
|
@ -1,3 +1,4 @@
|
||||||
|
# type: ignore
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -47,11 +48,7 @@ class Google:
|
||||||
self.__GOOGLE_SEARCH_ENGINE_ID = str(
|
self.__GOOGLE_SEARCH_ENGINE_ID = str(
|
||||||
os.environ.get("GOOGLE_SEARCH_ENGINE_ID")
|
os.environ.get("GOOGLE_SEARCH_ENGINE_ID")
|
||||||
)
|
)
|
||||||
self.__num_res: int = (
|
self.__num_res: int = 10
|
||||||
5
|
|
||||||
if config.NLP_CONF_MODE == "speed"
|
|
||||||
else (20 if config.NLP_CONF_MODE else 10)
|
|
||||||
)
|
|
||||||
self.__query = query
|
self.__query = query
|
||||||
self.__URL_EXTRACTOR: URLExtract = URLExtract()
|
self.__URL_EXTRACTOR: URLExtract = URLExtract()
|
||||||
self.__urls: list[str] = self.__URL_EXTRACTOR.find_urls(query)
|
self.__urls: list[str] = self.__URL_EXTRACTOR.find_urls(query)
|
||||||
|
@ -136,7 +133,7 @@ class Google:
|
||||||
self.__get_urls_contents()
|
self.__get_urls_contents()
|
||||||
if filter_irrelevant:
|
if filter_irrelevant:
|
||||||
self.__filter_irrelevant_processing()
|
self.__filter_irrelevant_processing()
|
||||||
results: tuple[list[str], list[str]] = (self.__content, self.__urls) # type: ignore
|
results: tuple[list[str], list[str]] = (self.__content, self.__urls)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue