- Added a picture taken to examples
- Corrected few bugs - added a retry loop if analysis did not go well
This commit is contained in:
parent
91e5a97b7d
commit
a021eb3c2f
40
main.py
40
main.py
@ -1,13 +1,39 @@
|
|||||||
from fastapi import FastAPI
|
import dotenv
|
||||||
|
from fastapi import FastAPI, File, Form, HTTPException
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
app = FastAPI()
|
from src.model import text_recipe_analysis, img_recipe_analysis, find_replacement
|
||||||
|
|
||||||
|
app = FastAPI(debug=True, title="GoodRecipeAPI", version="1.0")
|
||||||
|
|
||||||
|
key = dotenv.get_key(".env", "MISTRAL_API_KEY")
|
||||||
|
args = {"model": "mistral-small-2506", "api_key": key}
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def home():
|
||||||
return {"message": "Hello World"}
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
#TODO Check prompt injection
|
||||||
|
|
||||||
|
@app.post("/analysis/text")
|
||||||
|
async def text_analysis(recipe: Annotated[str, Form()]):
|
||||||
|
try:
|
||||||
|
return text_recipe_analysis(recipe, **args)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the recipe: {e}")
|
||||||
|
|
||||||
|
@app.post("/analysis/img")
|
||||||
|
async def img_analysis(recipe: Annotated[bytes, File()]):
|
||||||
|
try:
|
||||||
|
return img_recipe_analysis(recipe, **args)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the image: {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/hello/{name}")
|
@app.get("/replacements")
|
||||||
async def say_hello(name: str):
|
async def replace(ingredients: str):
|
||||||
return {"message": f"Hello {name}"}
|
if ';' not in ingredients or ingredients[-1] != ';':
|
||||||
|
raise HTTPException(status_code=500, detail=f"Ingredients should be separated by ';' and end with ';'")
|
||||||
|
ingredients = ingredients.split(";")[:-1] # last ingredient is empty
|
||||||
|
|
||||||
|
return find_replacement(ingredients, **args)
|
||||||
|
|||||||
BIN
ressources/imgs/test-recipe.jpg
Normal file
BIN
ressources/imgs/test-recipe.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 158 KiB |
41
src/model.py
41
src/model.py
@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from mistralai import Mistral
|
from mistralai import Mistral
|
||||||
|
|
||||||
from .utils import create_message, encode_image, response2text
|
from .utils import create_message, encode_image, response2text, test_img
|
||||||
|
|
||||||
|
|
||||||
def text_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
def text_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||||
@ -36,18 +36,33 @@ The json must be containing the fields:
|
|||||||
If the original language of the recipe is not english, translate all the ingredients name and the preparation process to english.
|
If the original language of the recipe is not english, translate all the ingredients name and the preparation process to english.
|
||||||
Here is the recipe you have to analyze
|
Here is the recipe you have to analyze
|
||||||
|
|
||||||
|
If the following text is not a recipe return a json object with the field "error" set to boolean value True.
|
||||||
|
|
||||||
RECIPE:
|
RECIPE:
|
||||||
{recipe}
|
{recipe}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client = Mistral(api_key=kwargs["api_key"])
|
client = Mistral(api_key=kwargs["api_key"])
|
||||||
return json.loads(client.chat.complete(
|
|
||||||
|
answer = json.loads(client.chat.complete(
|
||||||
model=kwargs["model"],
|
model=kwargs["model"],
|
||||||
messages=[create_message("user", task)],
|
messages=[create_message("user", task)],
|
||||||
response_format={ "type": "json_object" }
|
response_format={ "type": "json_object" }
|
||||||
).choices[0].message.content)
|
).choices[0].message.content)
|
||||||
|
|
||||||
def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
while "ingredients" not in answer or "preparation" not in answer:
|
||||||
|
print("Trying generation again")
|
||||||
|
answer = json.loads(client.chat.complete(
|
||||||
|
model=kwargs["model"],
|
||||||
|
messages=[create_message("user", task)],
|
||||||
|
response_format={"type": "json_object"}
|
||||||
|
).choices[0].message.content)
|
||||||
|
|
||||||
|
if "error" in answer:
|
||||||
|
raise ValueError("The target text is not a recipe")
|
||||||
|
return answer
|
||||||
|
|
||||||
|
def img_recipe_analysis(recipe: bytes, **kwargs) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Analyze a recipe in image returning ingredients and preparation steps
|
Analyze a recipe in image returning ingredients and preparation steps
|
||||||
Call a MistralAI model with parameters in kwargs
|
Call a MistralAI model with parameters in kwargs
|
||||||
@ -55,6 +70,8 @@ def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
|||||||
:param kwargs:
|
:param kwargs:
|
||||||
:return: dict containing ingredients and preparation method
|
:return: dict containing ingredients and preparation method
|
||||||
"""
|
"""
|
||||||
|
if not test_img(recipe):
|
||||||
|
raise ValueError("Provided bytes are not an image file")
|
||||||
|
|
||||||
img = encode_image(recipe)
|
img = encode_image(recipe)
|
||||||
|
|
||||||
@ -70,8 +87,9 @@ def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
|||||||
|
|
||||||
text = "".join([page.markdown for page in ocr_answer])
|
text = "".join([page.markdown for page in ocr_answer])
|
||||||
|
|
||||||
with open("test.md", 'w') as f:
|
# Test purposes
|
||||||
f.write(text)
|
# with open("test.md", 'w') as f:
|
||||||
|
# f.write(text)
|
||||||
|
|
||||||
return text_recipe_analysis(text, **kwargs)
|
return text_recipe_analysis(text, **kwargs)
|
||||||
|
|
||||||
@ -83,7 +101,6 @@ def find_replacement(ingredients: list[str], **kwargs) -> dict[str, Any]:
|
|||||||
:return: a dict with a list of replacements and indications about these replacements
|
:return: a dict with a list of replacements and indications about these replacements
|
||||||
"""
|
"""
|
||||||
client = Mistral(api_key=kwargs["api_key"])
|
client = Mistral(api_key=kwargs["api_key"])
|
||||||
|
|
||||||
agent = client.beta.agents.create(
|
agent = client.beta.agents.create(
|
||||||
name="GoodRecipe Agent",
|
name="GoodRecipe Agent",
|
||||||
description="Agent to retrieve ingredient replacements with web search",
|
description="Agent to retrieve ingredient replacements with web search",
|
||||||
@ -140,12 +157,18 @@ TEXT:
|
|||||||
|
|
||||||
""" + response
|
""" + response
|
||||||
|
|
||||||
answer = client.chat.complete(
|
answer = json.loads(client.chat.complete(
|
||||||
model=kwargs["model"],
|
model=kwargs["model"],
|
||||||
messages=[create_message("user", json_task)],
|
messages=[create_message("user", json_task)],
|
||||||
response_format={"type": "json_object"}
|
response_format={"type": "json_object"}
|
||||||
).choices[0].message.content
|
).choices[0].message.content)
|
||||||
|
|
||||||
replacements[ingredient] = json.loads(answer)['substitutes']
|
if type(answer) == list:
|
||||||
|
replacements[ingredient] = answer
|
||||||
|
elif type(answer) == dict and "substitutes" in answer:
|
||||||
|
replacements[ingredient] = answer['substitutes']
|
||||||
|
else:
|
||||||
|
# For now
|
||||||
|
replacements[ingredient] = []
|
||||||
|
|
||||||
return replacements
|
return replacements
|
||||||
27
src/utils.py
27
src/utils.py
@ -4,7 +4,9 @@
|
|||||||
:brief: Utility functions
|
:brief: Utility functions
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
from mistralai import ConversationResponse, TextChunk
|
from mistralai import ConversationResponse, TextChunk
|
||||||
|
|
||||||
|
|
||||||
@ -18,18 +20,21 @@ def create_message(role: str, content: str) -> dict[str, str]:
|
|||||||
return {"role": role, "content": content}
|
return {"role": role, "content": content}
|
||||||
|
|
||||||
# from https://docs.mistral.ai/capabilities/vision/
|
# from https://docs.mistral.ai/capabilities/vision/
|
||||||
def encode_image(image_path):
|
def encode_image(image_bytes: bytes) -> str:
|
||||||
"""Encode the image to base64."""
|
"""Encode the image to base64."""
|
||||||
try:
|
return base64.b64encode(image_bytes).decode('utf-8')
|
||||||
with open(image_path, "rb") as image_file:
|
|
||||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: The file {image_path} was not found.")
|
|
||||||
return None
|
|
||||||
except Exception as e: # Added general exception handling
|
|
||||||
print(f"Error: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def response2text(response: ConversationResponse) -> str:
|
def response2text(response: ConversationResponse) -> str:
|
||||||
text = [output.text for output in response.outputs[1].content if type(output) is TextChunk]
|
text = [output.text for output in response.outputs[1].content if type(output) is TextChunk]
|
||||||
return "".join(text)
|
return "".join(text)
|
||||||
|
|
||||||
|
def test_img(img: bytes):
|
||||||
|
"""
|
||||||
|
Test if the bytes provided are a valid image or not
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image = Image.open(io.BytesIO(img))
|
||||||
|
image.verify()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
return False
|
||||||
Loading…
x
Reference in New Issue
Block a user