diff --git a/app.py b/app.py index a10d3ec..a37e1ff 100644 --- a/app.py +++ b/app.py @@ -2,8 +2,9 @@ import os import io import base64 from dotenv import load_dotenv -from fastapi import FastAPI, Form +from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel from PIL import Image import moondream as md @@ -17,44 +18,47 @@ model = md.vl(api_key=api_key) # FastAPI app app = FastAPI() +# Utility to decode base64 def decode_base64_image(base64_str: str) -> Image.Image: try: - image_data = base64.b64decode(base64_str.split(",")[-1]) # strip data URL prefix if present + image_data = base64.b64decode(base64_str.split(",")[-1]) return Image.open(io.BytesIO(image_data)) except Exception as e: raise ValueError("Invalid base64 image") from e +# Request schemas +class CaptionRequest(BaseModel): + base64_image: str + length: str = "short" + +class QueryRequest(BaseModel): + base64_image: str + question: str + stream: bool = False + @app.post("/caption") -async def generate_caption( - base64_image: str = Form(...), - length: str = Form("short") -): +async def generate_caption(payload: CaptionRequest): try: - img = decode_base64_image(base64_image) - response = model.caption(img, length=length) - return JSONResponse(content={"caption": response["caption"]}) + img = decode_base64_image(payload.base64_image) + response = model.caption(img, length=payload.length) + return {"caption": response["caption"]} except Exception as e: - return JSONResponse(status_code=500, content={"error": str(e)}) + raise HTTPException(status_code=500, detail=str(e)) @app.post("/query") -async def query_image( - base64_image: str = Form(...), - question: str = Form(...), - stream: bool = Form(False) -): +async def query_image(payload: QueryRequest): try: - img = decode_base64_image(base64_image) + img = decode_base64_image(payload.base64_image) - if stream: + if payload.stream: def generate(): - result = model.query(img, question, stream=True) + result = model.query(img, payload.question, stream=True) for chunk in result["chunk"]: yield chunk - return StreamingResponse(generate(), media_type="text/plain") else: - result = model.query(img, question) - return JSONResponse(content={"answer": result["answer"]}) + result = model.query(img, payload.question) + return {"answer": result["answer"]} except Exception as e: - return JSONResponse(status_code=500, content={"error": str(e)}) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/requirements.txt b/requirements.txt index b36e2a2..c1b2d1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ Pillow python-dotenv fastapi uvicorn -python-multipart \ No newline at end of file +python-multipart +pydantic \ No newline at end of file