Update wg_api.py
parent
90b56d1ef0
commit
3e60939ffc
42
wg_api.py
42
wg_api.py
|
@ -1,33 +1,55 @@
|
|||
import subprocess
|
||||
import base64
|
||||
from fastapi import FastAPI, Form, HTTPException, Query, Header, Depends
|
||||
from fastapi import FastAPI, Form, HTTPException, Query, Header, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pathlib import Path
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
# Load environment variables
|
||||
# Rate limit support
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# Load .env
|
||||
load_dotenv()
|
||||
API_KEY_ENV = os.getenv("API_KEY")
|
||||
WHITELIST = os.getenv("WHITELIST", "").split(",")
|
||||
RATE_LIMIT = os.getenv("RATE_LIMIT", "5/minute") # e.g. 5 requests/minute
|
||||
|
||||
WG_DIR = Path("/etc/wireguard")
|
||||
SCRIPT_PATH = Path("/etc/wireguard/wg_config.sh")
|
||||
|
||||
# Setup FastAPI and rate limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app = FastAPI()
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Dependency to enforce API key
|
||||
|
||||
# --- Dependencies ---
|
||||
def verify_api_key(x_api_key: str = Header(...)):
|
||||
if x_api_key != API_KEY_ENV:
|
||||
raise HTTPException(status_code=403, detail="Invalid API key")
|
||||
|
||||
def verify_ip(request: Request):
|
||||
client_ip = request.client.host
|
||||
if client_ip not in WHITELIST:
|
||||
raise HTTPException(status_code=403, detail=f"IP {client_ip} not allowed")
|
||||
|
||||
def is_valid_name(name):
|
||||
return name.isalnum()
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
@app.post("/vpn")
|
||||
@limiter.limit(RATE_LIMIT)
|
||||
async def create_vpn_client(
|
||||
request: Request,
|
||||
new: str = Form(...),
|
||||
auth: None = Depends(verify_api_key)
|
||||
_: None = Depends(verify_api_key),
|
||||
__: None = Depends(verify_ip)
|
||||
):
|
||||
client_name = new.strip()
|
||||
if not is_valid_name(client_name):
|
||||
|
@ -68,9 +90,12 @@ async def create_vpn_client(
|
|||
|
||||
|
||||
@app.delete("/vpn")
|
||||
@limiter.limit(RATE_LIMIT)
|
||||
async def remove_vpn_client(
|
||||
request: Request,
|
||||
remove: str = Query(...),
|
||||
auth: None = Depends(verify_api_key)
|
||||
_: None = Depends(verify_api_key),
|
||||
__: None = Depends(verify_ip)
|
||||
):
|
||||
client_name = remove.strip()
|
||||
if not is_valid_name(client_name):
|
||||
|
@ -93,7 +118,12 @@ async def remove_vpn_client(
|
|||
|
||||
|
||||
@app.get("/vpn/list")
|
||||
async def list_vpn_clients(auth: None = Depends(verify_api_key)):
|
||||
@limiter.limit(RATE_LIMIT)
|
||||
async def list_vpn_clients(
|
||||
request: Request,
|
||||
_: None = Depends(verify_api_key),
|
||||
__: None = Depends(verify_ip)
|
||||
):
|
||||
clients = []
|
||||
for dir in WG_DIR.iterdir():
|
||||
if dir.is_dir():
|
||||
|
|
Loading…
Reference in New Issue