- Added a picture taken to examples
- Corrected few bugs - added a retry loop if analysis did not go well
This commit is contained in:
parent
91e5a97b7d
commit
a021eb3c2f
40
main.py
40
main.py
@ -1,13 +1,39 @@
|
||||
from fastapi import FastAPI
|
||||
import dotenv
|
||||
from fastapi import FastAPI, File, Form, HTTPException
|
||||
from typing import Annotated
|
||||
|
||||
app = FastAPI()
|
||||
from src.model import text_recipe_analysis, img_recipe_analysis, find_replacement
|
||||
|
||||
app = FastAPI(debug=True, title="GoodRecipeAPI", version="1.0")
|
||||
|
||||
key = dotenv.get_key(".env", "MISTRAL_API_KEY")
|
||||
args = {"model": "mistral-small-2506", "api_key": key}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Hello World"}
|
||||
async def home():
|
||||
return {"message": "ok"}
|
||||
|
||||
#TODO Check prompt injection
|
||||
|
||||
@app.post("/analysis/text")
|
||||
async def text_analysis(recipe: Annotated[str, Form()]):
|
||||
try:
|
||||
return text_recipe_analysis(recipe, **args)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the recipe: {e}")
|
||||
|
||||
@app.post("/analysis/img")
|
||||
async def img_analysis(recipe: Annotated[bytes, File()]):
|
||||
try:
|
||||
return img_recipe_analysis(recipe, **args)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the image: {e}")
|
||||
|
||||
|
||||
@app.get("/hello/{name}")
|
||||
async def say_hello(name: str):
|
||||
return {"message": f"Hello {name}"}
|
||||
@app.get("/replacements")
|
||||
async def replace(ingredients: str):
|
||||
if ';' not in ingredients or ingredients[-1] != ';':
|
||||
raise HTTPException(status_code=500, detail=f"Ingredients should be separated by ';' and end with ';'")
|
||||
ingredients = ingredients.split(";")[:-1] # last ingredient is empty
|
||||
|
||||
return find_replacement(ingredients, **args)
|
||||
|
||||
BIN
ressources/imgs/test-recipe.jpg
Normal file
BIN
ressources/imgs/test-recipe.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 158 KiB |
41
src/model.py
41
src/model.py
@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
from mistralai import Mistral
|
||||
|
||||
from .utils import create_message, encode_image, response2text
|
||||
from .utils import create_message, encode_image, response2text, test_img
|
||||
|
||||
|
||||
def text_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||
@ -36,18 +36,33 @@ The json must be containing the fields:
|
||||
If the original language of the recipe is not english, translate all the ingredients name and the preparation process to english.
|
||||
Here is the recipe you have to analyze
|
||||
|
||||
If the following text is not a recipe return a json object with the field "error" set to boolean value True.
|
||||
|
||||
RECIPE:
|
||||
{recipe}
|
||||
"""
|
||||
|
||||
client = Mistral(api_key=kwargs["api_key"])
|
||||
return json.loads(client.chat.complete(
|
||||
|
||||
answer = json.loads(client.chat.complete(
|
||||
model=kwargs["model"],
|
||||
messages=[create_message("user", task)],
|
||||
response_format={ "type": "json_object" }
|
||||
).choices[0].message.content)
|
||||
|
||||
def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||
while "ingredients" not in answer or "preparation" not in answer:
|
||||
print("Trying generation again")
|
||||
answer = json.loads(client.chat.complete(
|
||||
model=kwargs["model"],
|
||||
messages=[create_message("user", task)],
|
||||
response_format={"type": "json_object"}
|
||||
).choices[0].message.content)
|
||||
|
||||
if "error" in answer:
|
||||
raise ValueError("The target text is not a recipe")
|
||||
return answer
|
||||
|
||||
def img_recipe_analysis(recipe: bytes, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Analyze a recipe in image returning ingredients and preparation steps
|
||||
Call a MistralAI model with parameters in kwargs
|
||||
@ -55,6 +70,8 @@ def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||
:param kwargs:
|
||||
:return: dict containing ingredients and preparation method
|
||||
"""
|
||||
if not test_img(recipe):
|
||||
raise ValueError("Provided bytes are not an image file")
|
||||
|
||||
img = encode_image(recipe)
|
||||
|
||||
@ -70,8 +87,9 @@ def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||
|
||||
text = "".join([page.markdown for page in ocr_answer])
|
||||
|
||||
with open("test.md", 'w') as f:
|
||||
f.write(text)
|
||||
# Test purposes
|
||||
# with open("test.md", 'w') as f:
|
||||
# f.write(text)
|
||||
|
||||
return text_recipe_analysis(text, **kwargs)
|
||||
|
||||
@ -83,7 +101,6 @@ def find_replacement(ingredients: list[str], **kwargs) -> dict[str, Any]:
|
||||
:return: a dict with a list of replacements and indications about these replacements
|
||||
"""
|
||||
client = Mistral(api_key=kwargs["api_key"])
|
||||
|
||||
agent = client.beta.agents.create(
|
||||
name="GoodRecipe Agent",
|
||||
description="Agent to retrieve ingredient replacements with web search",
|
||||
@ -140,12 +157,18 @@ TEXT:
|
||||
|
||||
""" + response
|
||||
|
||||
answer = client.chat.complete(
|
||||
answer = json.loads(client.chat.complete(
|
||||
model=kwargs["model"],
|
||||
messages=[create_message("user", json_task)],
|
||||
response_format={"type": "json_object"}
|
||||
).choices[0].message.content
|
||||
).choices[0].message.content)
|
||||
|
||||
replacements[ingredient] = json.loads(answer)['substitutes']
|
||||
if type(answer) == list:
|
||||
replacements[ingredient] = answer
|
||||
elif type(answer) == dict and "substitutes" in answer:
|
||||
replacements[ingredient] = answer['substitutes']
|
||||
else:
|
||||
# For now
|
||||
replacements[ingredient] = []
|
||||
|
||||
return replacements
|
||||
27
src/utils.py
27
src/utils.py
@ -4,7 +4,9 @@
|
||||
:brief: Utility functions
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
from mistralai import ConversationResponse, TextChunk
|
||||
|
||||
|
||||
@ -18,18 +20,21 @@ def create_message(role: str, content: str) -> dict[str, str]:
|
||||
return {"role": role, "content": content}
|
||||
|
||||
# from https://docs.mistral.ai/capabilities/vision/
|
||||
def encode_image(image_path):
|
||||
def encode_image(image_bytes: bytes) -> str:
|
||||
"""Encode the image to base64."""
|
||||
try:
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file {image_path} was not found.")
|
||||
return None
|
||||
except Exception as e: # Added general exception handling
|
||||
print(f"Error: {e}")
|
||||
return None
|
||||
return base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
def response2text(response: ConversationResponse) -> str:
|
||||
text = [output.text for output in response.outputs[1].content if type(output) is TextChunk]
|
||||
return "".join(text)
|
||||
return "".join(text)
|
||||
|
||||
def test_img(img: bytes):
|
||||
"""
|
||||
Test if the bytes provided are a valid image or not
|
||||
"""
|
||||
try:
|
||||
image = Image.open(io.BytesIO(img))
|
||||
image.verify()
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
Loading…
x
Reference in New Issue
Block a user