Compare commits
No commits in common. "a021eb3c2fb250c74e91d3285dccb2081a3a8a01" and "75bd6ba48a939dbd9d9eca87f425168b4b2688ff" have entirely different histories.
a021eb3c2f
...
75bd6ba48a
40
main.py
40
main.py
@ -1,39 +1,13 @@
|
|||||||
import dotenv
|
from fastapi import FastAPI
|
||||||
from fastapi import FastAPI, File, Form, HTTPException
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from src.model import text_recipe_analysis, img_recipe_analysis, find_replacement
|
app = FastAPI()
|
||||||
|
|
||||||
app = FastAPI(debug=True, title="GoodRecipeAPI", version="1.0")
|
|
||||||
|
|
||||||
key = dotenv.get_key(".env", "MISTRAL_API_KEY")
|
|
||||||
args = {"model": "mistral-small-2506", "api_key": key}
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def home():
|
async def root():
|
||||||
return {"message": "ok"}
|
return {"message": "Hello World"}
|
||||||
|
|
||||||
#TODO Check prompt injection
|
|
||||||
|
|
||||||
@app.post("/analysis/text")
|
|
||||||
async def text_analysis(recipe: Annotated[str, Form()]):
|
|
||||||
try:
|
|
||||||
return text_recipe_analysis(recipe, **args)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the recipe: {e}")
|
|
||||||
|
|
||||||
@app.post("/analysis/img")
|
|
||||||
async def img_analysis(recipe: Annotated[bytes, File()]):
|
|
||||||
try:
|
|
||||||
return img_recipe_analysis(recipe, **args)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"An error occured trying to analyze the image: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/replacements")
|
@app.get("/hello/{name}")
|
||||||
async def replace(ingredients: str):
|
async def say_hello(name: str):
|
||||||
if ';' not in ingredients or ingredients[-1] != ';':
|
return {"message": f"Hello {name}"}
|
||||||
raise HTTPException(status_code=500, detail=f"Ingredients should be separated by ';' and end with ';'")
|
|
||||||
ingredients = ingredients.split(";")[:-1] # last ingredient is empty
|
|
||||||
|
|
||||||
return find_replacement(ingredients, **args)
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
fastapi[standard]
|
fastapi
|
||||||
google
|
google
|
||||||
requests
|
requests
|
||||||
dotenv
|
dotenv
|
||||||
mistralai
|
mistralai
|
||||||
pillow
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 158 KiB |
41
src/model.py
41
src/model.py
@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from mistralai import Mistral
|
from mistralai import Mistral
|
||||||
|
|
||||||
from .utils import create_message, encode_image, response2text, test_img
|
from .utils import create_message, encode_image, response2text
|
||||||
|
|
||||||
|
|
||||||
def text_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
def text_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||||
@ -36,33 +36,18 @@ The json must be containing the fields:
|
|||||||
If the original language of the recipe is not english, translate all the ingredients name and the preparation process to english.
|
If the original language of the recipe is not english, translate all the ingredients name and the preparation process to english.
|
||||||
Here is the recipe you have to analyze
|
Here is the recipe you have to analyze
|
||||||
|
|
||||||
If the following text is not a recipe return a json object with the field "error" set to boolean value True.
|
|
||||||
|
|
||||||
RECIPE:
|
RECIPE:
|
||||||
{recipe}
|
{recipe}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client = Mistral(api_key=kwargs["api_key"])
|
client = Mistral(api_key=kwargs["api_key"])
|
||||||
|
return json.loads(client.chat.complete(
|
||||||
answer = json.loads(client.chat.complete(
|
|
||||||
model=kwargs["model"],
|
model=kwargs["model"],
|
||||||
messages=[create_message("user", task)],
|
messages=[create_message("user", task)],
|
||||||
response_format={ "type": "json_object" }
|
response_format={ "type": "json_object" }
|
||||||
).choices[0].message.content)
|
).choices[0].message.content)
|
||||||
|
|
||||||
while "ingredients" not in answer or "preparation" not in answer:
|
def img_recipe_analysis(recipe: str, **kwargs) -> dict[str, Any]:
|
||||||
print("Trying generation again")
|
|
||||||
answer = json.loads(client.chat.complete(
|
|
||||||
model=kwargs["model"],
|
|
||||||
messages=[create_message("user", task)],
|
|
||||||
response_format={"type": "json_object"}
|
|
||||||
).choices[0].message.content)
|
|
||||||
|
|
||||||
if "error" in answer:
|
|
||||||
raise ValueError("The target text is not a recipe")
|
|
||||||
return answer
|
|
||||||
|
|
||||||
def img_recipe_analysis(recipe: bytes, **kwargs) -> dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
Analyze a recipe in image returning ingredients and preparation steps
|
Analyze a recipe in image returning ingredients and preparation steps
|
||||||
Call a MistralAI model with parameters in kwargs
|
Call a MistralAI model with parameters in kwargs
|
||||||
@ -70,8 +55,6 @@ def img_recipe_analysis(recipe: bytes, **kwargs) -> dict[str, Any]:
|
|||||||
:param kwargs:
|
:param kwargs:
|
||||||
:return: dict containing ingredients and preparation method
|
:return: dict containing ingredients and preparation method
|
||||||
"""
|
"""
|
||||||
if not test_img(recipe):
|
|
||||||
raise ValueError("Provided bytes are not an image file")
|
|
||||||
|
|
||||||
img = encode_image(recipe)
|
img = encode_image(recipe)
|
||||||
|
|
||||||
@ -87,9 +70,8 @@ def img_recipe_analysis(recipe: bytes, **kwargs) -> dict[str, Any]:
|
|||||||
|
|
||||||
text = "".join([page.markdown for page in ocr_answer])
|
text = "".join([page.markdown for page in ocr_answer])
|
||||||
|
|
||||||
# Test purposes
|
with open("test.md", 'w') as f:
|
||||||
# with open("test.md", 'w') as f:
|
f.write(text)
|
||||||
# f.write(text)
|
|
||||||
|
|
||||||
return text_recipe_analysis(text, **kwargs)
|
return text_recipe_analysis(text, **kwargs)
|
||||||
|
|
||||||
@ -101,6 +83,7 @@ def find_replacement(ingredients: list[str], **kwargs) -> dict[str, Any]:
|
|||||||
:return: a dict with a list of replacements and indications about these replacements
|
:return: a dict with a list of replacements and indications about these replacements
|
||||||
"""
|
"""
|
||||||
client = Mistral(api_key=kwargs["api_key"])
|
client = Mistral(api_key=kwargs["api_key"])
|
||||||
|
|
||||||
agent = client.beta.agents.create(
|
agent = client.beta.agents.create(
|
||||||
name="GoodRecipe Agent",
|
name="GoodRecipe Agent",
|
||||||
description="Agent to retrieve ingredient replacements with web search",
|
description="Agent to retrieve ingredient replacements with web search",
|
||||||
@ -157,18 +140,12 @@ TEXT:
|
|||||||
|
|
||||||
""" + response
|
""" + response
|
||||||
|
|
||||||
answer = json.loads(client.chat.complete(
|
answer = client.chat.complete(
|
||||||
model=kwargs["model"],
|
model=kwargs["model"],
|
||||||
messages=[create_message("user", json_task)],
|
messages=[create_message("user", json_task)],
|
||||||
response_format={"type": "json_object"}
|
response_format={"type": "json_object"}
|
||||||
).choices[0].message.content)
|
).choices[0].message.content
|
||||||
|
|
||||||
if type(answer) == list:
|
replacements[ingredient] = json.loads(answer)['substitutes']
|
||||||
replacements[ingredient] = answer
|
|
||||||
elif type(answer) == dict and "substitutes" in answer:
|
|
||||||
replacements[ingredient] = answer['substitutes']
|
|
||||||
else:
|
|
||||||
# For now
|
|
||||||
replacements[ingredient] = []
|
|
||||||
|
|
||||||
return replacements
|
return replacements
|
||||||
27
src/utils.py
27
src/utils.py
@ -4,9 +4,7 @@
|
|||||||
:brief: Utility functions
|
:brief: Utility functions
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
import io
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from mistralai import ConversationResponse, TextChunk
|
from mistralai import ConversationResponse, TextChunk
|
||||||
|
|
||||||
|
|
||||||
@ -20,21 +18,18 @@ def create_message(role: str, content: str) -> dict[str, str]:
|
|||||||
return {"role": role, "content": content}
|
return {"role": role, "content": content}
|
||||||
|
|
||||||
# from https://docs.mistral.ai/capabilities/vision/
|
# from https://docs.mistral.ai/capabilities/vision/
|
||||||
def encode_image(image_bytes: bytes) -> str:
|
def encode_image(image_path):
|
||||||
"""Encode the image to base64."""
|
"""Encode the image to base64."""
|
||||||
return base64.b64encode(image_bytes).decode('utf-8')
|
try:
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: The file {image_path} was not found.")
|
||||||
|
return None
|
||||||
|
except Exception as e: # Added general exception handling
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def response2text(response: ConversationResponse) -> str:
|
def response2text(response: ConversationResponse) -> str:
|
||||||
text = [output.text for output in response.outputs[1].content if type(output) is TextChunk]
|
text = [output.text for output in response.outputs[1].content if type(output) is TextChunk]
|
||||||
return "".join(text)
|
return "".join(text)
|
||||||
|
|
||||||
def test_img(img: bytes):
|
|
||||||
"""
|
|
||||||
Test if the bytes provided are a valid image or not
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
image = Image.open(io.BytesIO(img))
|
|
||||||
image.verify()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
return False
|
|
||||||
Loading…
x
Reference in New Issue
Block a user