remove domain whitelisting and JWT Authenticate

main
Subhodip Ghosh 2025-07-01 19:38:16 +05:30
parent ede388e184
commit b4e4bfe894
1 changed files with 36 additions and 144 deletions

180
main.go
View File

@ -2,7 +2,6 @@ package main
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -10,31 +9,19 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/dgrijalva/jwt-go"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/rs/cors"
) )
var db *sql.DB var db *sql.DB
var jwtSecret []byte
type Config struct {
APIKeys map[string]APIKeyConfig `json:"api_keys"`
}
type APIKeyConfig struct {
Secret string `json:"secret"`
AllowedURLs []string `json:"allowed_urls"`
}
type ChatRequest struct { type ChatRequest struct {
Message string `json:"message"` Message string `json:"message"`
Domain string `json:"domain"` Domain string `json:"domain"`
UserID string `json:"user_id"` // Added user_id for basic tracking
} }
type ChatResponse struct { type ChatResponse struct {
@ -61,10 +48,6 @@ func main() {
dbHost := os.Getenv("DB_HOST") dbHost := os.Getenv("DB_HOST")
dbPort := os.Getenv("DB_PORT") dbPort := os.Getenv("DB_PORT")
dbName := os.Getenv("DB_NAME") dbName := os.Getenv("DB_NAME")
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
jwtSecret = []byte(os.Getenv("JWT_SECRET"))
origins := strings.Split(allowedOrigins, ",")
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName) dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)
db, err = sql.Open("mysql", dsn) db, err = sql.Open("mysql", dsn)
@ -75,21 +58,27 @@ func main() {
initDB() initDB()
c := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowCredentials: true,
})
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST") router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS")
router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET") router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS")
handler := c.Handler(router) // Simple CORS middleware
corsMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
return
}
next.ServeHTTP(w, r)
})
}
log.Println("Server started on :8080") log.Println("Server started on :8080")
log.Fatal(http.ListenAndServe(":8080", handler)) log.Fatal(http.ListenAndServe(":8080", corsMiddleware(router)))
} }
func initDB() { func initDB() {
@ -108,105 +97,6 @@ func initDB() {
} }
} }
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
log.Println("Authorization header missing")
http.Error(w, "Authorization header required", http.StatusUnauthorized)
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
log.Println("Bearer token not found in Authorization header")
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
return
}
unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err == nil {
if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok {
log.Printf("Token claims (unverified): %+v\n", claims)
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
if time.Now().After(expTime) {
log.Printf("Token expired at: %v\n", expTime)
http.Error(w, "Token expired", http.StatusUnauthorized)
return
}
}
}
} else {
log.Println("Unverified parse error:", err)
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil {
log.Println("Token verification failed:", err)
http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
return
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
userID, ok := claims["user_id"].(string)
if !ok || userID == "" {
log.Println("Invalid user_id in token claims")
http.Error(w, "Invalid user ID in token", http.StatusUnauthorized)
return
}
var allowedDomains []string
if domains, ok := claims["allowed_urls"].([]interface{}); ok {
for _, d := range domains {
if domain, ok := d.(string); ok {
allowedDomains = append(allowedDomains, domain)
}
}
}
var reqDomain string
if r.Method == "POST" {
var chatReq ChatRequest
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if err := json.Unmarshal(bodyBytes, &chatReq); err == nil {
reqDomain = chatReq.Domain
}
} else {
reqDomain = r.URL.Query().Get("domain")
}
domainAllowed := false
for _, domain := range allowedDomains {
if domain == reqDomain {
domainAllowed = true
break
}
}
if !domainAllowed {
log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains)
http.Error(w, "Domain not allowed", http.StatusForbidden)
return
}
ctx := context.WithValue(r.Context(), "userID", userID)
next.ServeHTTP(w, r.WithContext(ctx))
} else {
log.Println("Invalid token claims")
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
}
}
}
func chatHandler(w http.ResponseWriter, r *http.Request) { func chatHandler(w http.ResponseWriter, r *http.Request) {
var req ChatRequest var req ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -214,10 +104,9 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
userID, ok := r.Context().Value("userID").(string) // Generate a simple user ID if not provided
if !ok { if req.UserID == "" {
http.Error(w, "User ID not found", http.StatusInternalServerError) req.UserID = "guest_" + fmt.Sprintf("%d", time.Now().Unix())
return
} }
response, err := getOllamaResponse(req.Message) response, err := getOllamaResponse(req.Message)
@ -229,7 +118,7 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
_, err = db.Exec( _, err = db.Exec(
"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)", "INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
userID, req.Message, response, req.Domain, req.UserID, req.Message, response, req.Domain,
) )
if err != nil { if err != nil {
log.Println("Database error:", err) log.Println("Database error:", err)
@ -253,7 +142,7 @@ func getOllamaResponse(prompt string) (string, error) {
} }
requestData := OllamaRequest{ requestData := OllamaRequest{
Model: "gemma3:1b", // Using the model you have installed Model: "gemma3:1b",
Prompt: prompt + "(under 20 words)", Prompt: prompt + "(under 20 words)",
Stream: false, Stream: false,
} }
@ -287,22 +176,25 @@ func getOllamaResponse(prompt string) (string, error) {
} }
func historyHandler(w http.ResponseWriter, r *http.Request) { func historyHandler(w http.ResponseWriter, r *http.Request) {
userID, ok := r.Context().Value("userID").(string) userID := r.URL.Query().Get("user_id")
if !ok {
http.Error(w, "User ID not found", http.StatusInternalServerError)
return
}
domain := r.URL.Query().Get("domain") domain := r.URL.Query().Get("domain")
if domain == "" { if domain == "" {
http.Error(w, "Domain parameter is required", http.StatusBadRequest) http.Error(w, "Domain parameter is required", http.StatusBadRequest)
return return
} }
rows, err := db.Query( query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?"
"SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50", args := []interface{}{domain}
userID, domain,
) if userID != "" {
query += " AND user_id = ?"
args = append(args, userID)
}
query += " ORDER BY created_at DESC LIMIT 50"
rows, err := db.Query(query, args...)
if err != nil { if err != nil {
log.Println("Database error:", err) log.Println("Database error:", err)
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)