From 3e60939ffc90ab7f7358dcc7d60d8c0e2a7e431c Mon Sep 17 00:00:00 2001 From: Kar Date: Tue, 29 Apr 2025 14:33:05 +0000 Subject: [PATCH] Update wg_api.py --- wg_api.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/wg_api.py b/wg_api.py index 3a82f44..2ee743f 100644 --- a/wg_api.py +++ b/wg_api.py @@ -1,33 +1,55 @@ import subprocess import base64 -from fastapi import FastAPI, Form, HTTPException, Query, Header, Depends +from fastapi import FastAPI, Form, HTTPException, Query, Header, Depends, Request from fastapi.responses import JSONResponse from pathlib import Path import uvicorn from dotenv import load_dotenv import os -# Load environment variables +# Rate limit support +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded + +# Load .env load_dotenv() API_KEY_ENV = os.getenv("API_KEY") +WHITELIST = os.getenv("WHITELIST", "").split(",") +RATE_LIMIT = os.getenv("RATE_LIMIT", "5/minute") # e.g. 5 requests/minute WG_DIR = Path("/etc/wireguard") SCRIPT_PATH = Path("/etc/wireguard/wg_config.sh") +# Setup FastAPI and rate limiter +limiter = Limiter(key_func=get_remote_address) app = FastAPI() +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) -# Dependency to enforce API key + +# --- Dependencies --- def verify_api_key(x_api_key: str = Header(...)): if x_api_key != API_KEY_ENV: raise HTTPException(status_code=403, detail="Invalid API key") +def verify_ip(request: Request): + client_ip = request.client.host + if client_ip not in WHITELIST: + raise HTTPException(status_code=403, detail=f"IP {client_ip} not allowed") + def is_valid_name(name): return name.isalnum() + +# --- Endpoints --- @app.post("/vpn") +@limiter.limit(RATE_LIMIT) async def create_vpn_client( + request: Request, new: str = Form(...), - auth: None = Depends(verify_api_key) + _: None = Depends(verify_api_key), + __: None = Depends(verify_ip) ): client_name = new.strip() if not is_valid_name(client_name): @@ -68,9 +90,12 @@ async def create_vpn_client( @app.delete("/vpn") +@limiter.limit(RATE_LIMIT) async def remove_vpn_client( + request: Request, remove: str = Query(...), - auth: None = Depends(verify_api_key) + _: None = Depends(verify_api_key), + __: None = Depends(verify_ip) ): client_name = remove.strip() if not is_valid_name(client_name): @@ -93,7 +118,12 @@ async def remove_vpn_client( @app.get("/vpn/list") -async def list_vpn_clients(auth: None = Depends(verify_api_key)): +@limiter.limit(RATE_LIMIT) +async def list_vpn_clients( + request: Request, + _: None = Depends(verify_api_key), + __: None = Depends(verify_ip) +): clients = [] for dir in WG_DIR.iterdir(): if dir.is_dir():