65 lines
1.8 KiB
Python
65 lines
1.8 KiB
Python
import os
|
|
import io
|
|
import base64
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
from PIL import Image
|
|
import moondream as md
|
|
|
|
# Load environment variables before the API key is read from them.
load_dotenv()
api_key = os.environ.get("MOON_DREAM_KEY")

# Moondream client shared by the request handlers below.
model = md.vl(api_key=api_key)

# FastAPI application that the routes are registered on.
app = FastAPI()
|
|
|
|
# Utility to decode base64
|
|
def decode_base64_image(base64_str: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

    Everything up to and including the last comma is discarded, so both a raw
    base64 payload and a ``data:image/png;base64,<payload>`` URL are accepted.

    Args:
        base64_str: Base64-encoded image bytes, with or without a data-URL
            prefix.

    Returns:
        The decoded image, with its pixel data already loaded.

    Raises:
        ValueError: If the payload is not valid base64 or cannot be decoded
            as an image. The underlying error is chained as ``__cause__``.
    """
    try:
        image_data = base64.b64decode(base64_str.split(",")[-1])
        image = Image.open(io.BytesIO(image_data))
        # Image.open() is lazy: it only identifies the file from its header.
        # Force the full decode here so truncated or corrupt data surfaces as
        # this function's ValueError instead of failing later in the caller.
        image.load()
        return image
    except Exception as e:
        raise ValueError("Invalid base64 image") from e
|
|
|
|
# Request schemas
|
|
class CaptionRequest(BaseModel):
    """Request body for the ``/caption`` endpoint."""

    # Base64-encoded image; the handler passes it to decode_base64_image().
    base64_image: str
    # Forwarded unchanged as the ``length`` argument of ``model.caption``.
    length: str = "short"
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for the ``/query`` endpoint."""

    # Base64-encoded image; the handler passes it to decode_base64_image().
    base64_image: str
    # Question forwarded to ``model.query`` together with the image.
    question: str
    # When true the handler returns a streaming text response instead of JSON.
    stream: bool = False
|
|
|
|
@app.post("/caption")
async def generate_caption(payload: CaptionRequest):
    """Generate a caption for the image supplied in the request body.

    Args:
        payload: Request carrying the base64 image and the caption length.

    Returns:
        A dict of the form ``{"caption": <text>}``.

    Raises:
        HTTPException: 400 if the image cannot be decoded; 500 if the model
            call fails or its response has no ``"caption"`` entry.
    """
    # An undecodable image is the client's mistake, so report it as 400
    # rather than letting it fall through to the generic 500 handler.
    try:
        img = decode_base64_image(payload.base64_image)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        response = model.caption(img, length=payload.length)
        return {"caption": response["caption"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.post("/query")
async def query_image(payload: QueryRequest):
    """Answer a question about the image supplied in the request body.

    Args:
        payload: Request carrying the base64 image, the question, and the
            ``stream`` flag.

    Returns:
        ``{"answer": <text>}`` when ``payload.stream`` is false, otherwise a
        ``text/plain`` ``StreamingResponse`` that yields the answer in chunks.

    Raises:
        HTTPException: 400 if the image cannot be decoded; 500 if the model
            call fails or its response has no ``"answer"`` entry.
    """
    # An undecodable image is the client's mistake, so report it as 400
    # rather than letting it fall through to the generic 500 handler.
    try:
        img = decode_base64_image(payload.base64_image)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        if payload.stream:
            # Start the query and look up the chunk iterator *before* the
            # response is created, so a failure here still becomes a 500.
            # The streamed text is read from the "answer" key, the same key
            # the non-streaming call uses.
            # NOTE(review): confirm against the installed moondream client.
            result = model.query(img, payload.question, stream=True)
            chunks = result["answer"]

            def generate():
                # Errors raised while iterating occur after the response
                # headers are sent and cannot be turned into an HTTP error.
                yield from chunks

            return StreamingResponse(generate(), media_type="text/plain")

        result = model.query(img, payload.question)
        return {"answer": result["answer"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|