remove domain whitelisting and JWT Authenticate
parent
ede388e184
commit
b4e4bfe894
180
main.go
180
main.go
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -10,31 +9,19 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/rs/cors"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
var jwtSecret []byte
|
||||
|
||||
type Config struct {
|
||||
APIKeys map[string]APIKeyConfig `json:"api_keys"`
|
||||
}
|
||||
|
||||
type APIKeyConfig struct {
|
||||
Secret string `json:"secret"`
|
||||
AllowedURLs []string `json:"allowed_urls"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message"`
|
||||
Domain string `json:"domain"`
|
||||
UserID string `json:"user_id"` // Added user_id for basic tracking
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
|
@ -61,10 +48,6 @@ func main() {
|
|||
dbHost := os.Getenv("DB_HOST")
|
||||
dbPort := os.Getenv("DB_PORT")
|
||||
dbName := os.Getenv("DB_NAME")
|
||||
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
|
||||
jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
||||
|
||||
origins := strings.Split(allowedOrigins, ",")
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)
|
||||
db, err = sql.Open("mysql", dsn)
|
||||
|
@ -75,21 +58,27 @@ func main() {
|
|||
|
||||
initDB()
|
||||
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: origins,
|
||||
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||
AllowCredentials: true,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST")
|
||||
router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET")
|
||||
router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS")
|
||||
|
||||
handler := c.Handler(router)
|
||||
// Simple CORS middleware
|
||||
corsMiddleware := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
log.Println("Server started on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", handler))
|
||||
log.Fatal(http.ListenAndServe(":8080", corsMiddleware(router)))
|
||||
}
|
||||
|
||||
func initDB() {
|
||||
|
@ -108,105 +97,6 @@ func initDB() {
|
|||
}
|
||||
}
|
||||
|
||||
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
log.Println("Authorization header missing")
|
||||
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
log.Println("Bearer token not found in Authorization header")
|
||||
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
|
||||
if err == nil {
|
||||
if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok {
|
||||
log.Printf("Token claims (unverified): %+v\n", claims)
|
||||
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
if time.Now().After(expTime) {
|
||||
log.Printf("Token expired at: %v\n", expTime)
|
||||
http.Error(w, "Token expired", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Println("Unverified parse error:", err)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Println("Token verification failed:", err)
|
||||
http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
log.Println("Invalid user_id in token claims")
|
||||
http.Error(w, "Invalid user ID in token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var allowedDomains []string
|
||||
if domains, ok := claims["allowed_urls"].([]interface{}); ok {
|
||||
for _, d := range domains {
|
||||
if domain, ok := d.(string); ok {
|
||||
allowedDomains = append(allowedDomains, domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var reqDomain string
|
||||
if r.Method == "POST" {
|
||||
var chatReq ChatRequest
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
if err := json.Unmarshal(bodyBytes, &chatReq); err == nil {
|
||||
reqDomain = chatReq.Domain
|
||||
}
|
||||
} else {
|
||||
reqDomain = r.URL.Query().Get("domain")
|
||||
}
|
||||
|
||||
domainAllowed := false
|
||||
for _, domain := range allowedDomains {
|
||||
if domain == reqDomain {
|
||||
domainAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !domainAllowed {
|
||||
log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains)
|
||||
http.Error(w, "Domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "userID", userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
} else {
|
||||
log.Println("Invalid token claims")
|
||||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chatHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var req ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
|
@ -214,10 +104,9 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
userID, ok := r.Context().Value("userID").(string)
|
||||
if !ok {
|
||||
http.Error(w, "User ID not found", http.StatusInternalServerError)
|
||||
return
|
||||
// Generate a simple user ID if not provided
|
||||
if req.UserID == "" {
|
||||
req.UserID = "guest_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
|
||||
response, err := getOllamaResponse(req.Message)
|
||||
|
@ -229,7 +118,7 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
|
||||
userID, req.Message, response, req.Domain,
|
||||
req.UserID, req.Message, response, req.Domain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Println("Database error:", err)
|
||||
|
@ -253,7 +142,7 @@ func getOllamaResponse(prompt string) (string, error) {
|
|||
}
|
||||
|
||||
requestData := OllamaRequest{
|
||||
Model: "gemma3:1b", // Using the model you have installed
|
||||
Model: "gemma3:1b",
|
||||
Prompt: prompt + "(under 20 words)",
|
||||
Stream: false,
|
||||
}
|
||||
|
@ -287,22 +176,25 @@ func getOllamaResponse(prompt string) (string, error) {
|
|||
}
|
||||
|
||||
func historyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := r.Context().Value("userID").(string)
|
||||
if !ok {
|
||||
http.Error(w, "User ID not found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
domain := r.URL.Query().Get("domain")
|
||||
|
||||
if domain == "" {
|
||||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
"SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50",
|
||||
userID, domain,
|
||||
)
|
||||
query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?"
|
||||
args := []interface{}{domain}
|
||||
|
||||
if userID != "" {
|
||||
query += " AND user_id = ?"
|
||||
args = append(args, userID)
|
||||
}
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT 50"
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
log.Println("Database error:", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
|
|
Loading…
Reference in New Issue