Update wg_api.py

main
Kar 2025-04-29 14:33:05 +00:00
parent 90b56d1ef0
commit 3e60939ffc
1 changed files with 36 additions and 6 deletions

View File

@ -1,33 +1,55 @@
import subprocess import subprocess
import base64 import base64
from fastapi import FastAPI, Form, HTTPException, Query, Header, Depends from fastapi import FastAPI, Form, HTTPException, Query, Header, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pathlib import Path from pathlib import Path
import uvicorn import uvicorn
from dotenv import load_dotenv from dotenv import load_dotenv
import os import os
# Load environment variables # Rate limit support
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
# Load .env
load_dotenv() load_dotenv()
API_KEY_ENV = os.getenv("API_KEY") API_KEY_ENV = os.getenv("API_KEY")
WHITELIST = os.getenv("WHITELIST", "").split(",")
RATE_LIMIT = os.getenv("RATE_LIMIT", "5/minute") # e.g. 5 requests/minute
WG_DIR = Path("/etc/wireguard") WG_DIR = Path("/etc/wireguard")
SCRIPT_PATH = Path("/etc/wireguard/wg_config.sh") SCRIPT_PATH = Path("/etc/wireguard/wg_config.sh")
# Setup FastAPI and rate limiter
limiter = Limiter(key_func=get_remote_address)
app = FastAPI() app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Dependency to enforce API key
# --- Dependencies ---
def verify_api_key(x_api_key: str = Header(...)): def verify_api_key(x_api_key: str = Header(...)):
if x_api_key != API_KEY_ENV: if x_api_key != API_KEY_ENV:
raise HTTPException(status_code=403, detail="Invalid API key") raise HTTPException(status_code=403, detail="Invalid API key")
def verify_ip(request: Request):
client_ip = request.client.host
if client_ip not in WHITELIST:
raise HTTPException(status_code=403, detail=f"IP {client_ip} not allowed")
def is_valid_name(name): def is_valid_name(name):
return name.isalnum() return name.isalnum()
# --- Endpoints ---
@app.post("/vpn") @app.post("/vpn")
@limiter.limit(RATE_LIMIT)
async def create_vpn_client( async def create_vpn_client(
request: Request,
new: str = Form(...), new: str = Form(...),
auth: None = Depends(verify_api_key) _: None = Depends(verify_api_key),
__: None = Depends(verify_ip)
): ):
client_name = new.strip() client_name = new.strip()
if not is_valid_name(client_name): if not is_valid_name(client_name):
@ -68,9 +90,12 @@ async def create_vpn_client(
@app.delete("/vpn") @app.delete("/vpn")
@limiter.limit(RATE_LIMIT)
async def remove_vpn_client( async def remove_vpn_client(
request: Request,
remove: str = Query(...), remove: str = Query(...),
auth: None = Depends(verify_api_key) _: None = Depends(verify_api_key),
__: None = Depends(verify_ip)
): ):
client_name = remove.strip() client_name = remove.strip()
if not is_valid_name(client_name): if not is_valid_name(client_name):
@ -93,7 +118,12 @@ async def remove_vpn_client(
@app.get("/vpn/list") @app.get("/vpn/list")
async def list_vpn_clients(auth: None = Depends(verify_api_key)): @limiter.limit(RATE_LIMIT)
async def list_vpn_clients(
request: Request,
_: None = Depends(verify_api_key),
__: None = Depends(verify_ip)
):
clients = [] clients = []
for dir in WG_DIR.iterdir(): for dir in WG_DIR.iterdir():
if dir.is_dir(): if dir.is_dir():