diff --git a/main.py b/main.py index 61cb0a6..69d6478 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,39 @@ -from fastapi import FastAPI +import dotenv +from fastapi import FastAPI, File, Form, HTTPException +from typing import Annotated -app = FastAPI() +from src.model import text_recipe_analysis, img_recipe_analysis, find_replacement +app = FastAPI(debug=True, title="GoodRecipeAPI", version="1.0") + +key = dotenv.get_key(".env", "MISTRAL_API_KEY") +args = {"model": "mistral-small-2506", "api_key": key} @app.get("/") -async def root(): - return {"message": "Hello World"} +async def home(): + return {"message": "ok"} + +#TODO Check prompt injection + +@app.post("/analysis/text") +async def text_analysis(recipe: Annotated[str, Form()]): + try: + return text_recipe_analysis(recipe, **args) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the recipe: {e}") + +@app.post("/analysis/img") +async def img_analysis(recipe: Annotated[bytes, File()]): + try: + return img_recipe_analysis(recipe, **args) + except Exception as e: + raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the image: {e}") -@app.get("/hello/{name}") -async def say_hello(name: str): - return {"message": f"Hello {name}"} +@app.get("/replacements") +async def replace(ingredients: str): + if ';' not in ingredients or ingredients[-1] != ';': + raise HTTPException(status_code=500, detail=f"Ingredients should be separated by ';' and end with ';'") + ingredients = ingredients.split(";")[:-1] # last ingredient is empty + + return find_replacement(ingredients, **args) diff --git a/ressources/imgs/test-recipe.jpg b/ressources/imgs/test-recipe.jpg new file mode 100644 index 0000000..f5b50ab Binary files /dev/null and b/ressources/imgs/test-recipe.jpg differ diff --git a/src/model.py b/src/model.py index 2d5a240..d6b6e03 100644 --- a/src/model.py +++ b/src/model.py @@ -10,7 +10,7 @@ from typing import Any from mistralai import Mistral -from .utils import create_message, encode_image, response2text +from .utils import create_message, encode_image, response2text, test_img def text_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]: @@ -36,18 +36,33 @@ The json must be containing the fields: If the original language of the recipe is not english, translate all the ingredients name and the preparation process to english. Here is the recipe you have to analyze +If the following text is not a recipe return a json object with the field "error" set to boolean value True. + RECIPE: {recipe} """ client = Mistral(api_key=kwargs["api_key"]) - return json.loads(client.chat.complete( + + answer = json.loads(client.chat.complete( model=kwargs["model"], messages=[create_message("user", task)], response_format={ "type": "json_object" } ).choices[0].message.content) -def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]: + while "ingredients" not in answer or "preparation" not in answer: + print("Trying generation again") + answer = json.loads(client.chat.complete( + model=kwargs["model"], + messages=[create_message("user", task)], + response_format={"type": "json_object"} + ).choices[0].message.content) + + if "error" in answer: + raise ValueError("The target text is not a recipe") + return answer + +def img_recipe_analysis(recipe: bytes, **kwargs) -> dict[str, Any]: """ Analyze a recipe in image returning ingredients and preparation steps Call a MistralAI model with parameters in kwargs @@ -55,6 +70,8 @@ def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]: :param kwargs: :return: dict containing ingredients and preparation method """ + if not test_img(recipe): + raise ValueError("Provided bytes are not an image file") img = encode_image(recipe) @@ -70,8 +87,9 @@ def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]: text = "".join([page.markdown for page in ocr_answer]) - with open("test.md", 'w') as f: - f.write(text) + # Test purposes + # with open("test.md", 'w') as f: + # f.write(text) return text_recipe_analysis(text, **kwargs) @@ -83,7 +101,6 @@ def find_replacement(ingredients: list[str], **kwargs) -> dict[str, Any]: :return: a dict with a list of replacements and indications about these replacements """ client = Mistral(api_key=kwargs["api_key"]) - agent = client.beta.agents.create( name="GoodRecipe Agent", description="Agent to retrieve ingredient replacements with web search", @@ -140,12 +157,18 @@ TEXT: """ + response - answer = client.chat.complete( + answer = json.loads(client.chat.complete( model=kwargs["model"], messages=[create_message("user", json_task)], response_format={"type": "json_object"} - ).choices[0].message.content + ).choices[0].message.content) - replacements[ingredient] = json.loads(answer)['substitutes'] + if type(answer) == list: + replacements[ingredient] = answer + elif type(answer) == dict and "substitutes" in answer: + replacements[ingredient] = answer['substitutes'] + else: + # For now + replacements[ingredient] = [] return replacements \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 38ecaa1..48368da 100644 --- a/src/utils.py +++ b/src/utils.py @@ -4,7 +4,9 @@ :brief: Utility functions """ import base64 +import io +from PIL import Image from mistralai import ConversationResponse, TextChunk @@ -18,18 +20,21 @@ def create_message(role: str, content: str) -> dict[str, str]: return {"role": role, "content": content} # from https://docs.mistral.ai/capabilities/vision/ -def encode_image(image_path): +def encode_image(image_bytes: bytes) -> str: """Encode the image to base64.""" - try: - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') - except FileNotFoundError: - print(f"Error: The file {image_path} was not found.") - return None - except Exception as e: # Added general exception handling - print(f"Error: {e}") - return None + return base64.b64encode(image_bytes).decode('utf-8') def response2text(response: ConversationResponse) -> str: text = [output.text for output in response.outputs[1].content if type(output) is TextChunk] - return "".join(text) \ No newline at end of file + return "".join(text) + +def test_img(img: bytes): + """ + Test if the bytes provided are a valid image or not + """ + try: + image = Image.open(io.BytesIO(img)) + image.verify() + return True + except Exception as e: + return False \ No newline at end of file