remove domain whitelisting and JWT Authenticate

main
Subhodip Ghosh 2025-07-01 19:38:16 +05:30
parent ede388e184
commit b4e4bfe894
1 changed files with 36 additions and 144 deletions

180
main.go
View File

@ -2,7 +2,6 @@ package main
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
@ -10,31 +9,19 @@ import (
"log"
"net/http"
"os"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
_ "github.com/go-sql-driver/mysql"
"github.com/gorilla/mux"
"github.com/joho/godotenv"
"github.com/rs/cors"
)
var db *sql.DB
var jwtSecret []byte
type Config struct {
APIKeys map[string]APIKeyConfig `json:"api_keys"`
}
type APIKeyConfig struct {
Secret string `json:"secret"`
AllowedURLs []string `json:"allowed_urls"`
}
type ChatRequest struct {
Message string `json:"message"`
Domain string `json:"domain"`
UserID string `json:"user_id"` // Added user_id for basic tracking
}
type ChatResponse struct {
@ -61,10 +48,6 @@ func main() {
dbHost := os.Getenv("DB_HOST")
dbPort := os.Getenv("DB_PORT")
dbName := os.Getenv("DB_NAME")
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
jwtSecret = []byte(os.Getenv("JWT_SECRET"))
origins := strings.Split(allowedOrigins, ",")
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)
db, err = sql.Open("mysql", dsn)
@ -75,21 +58,27 @@ func main() {
initDB()
c := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowCredentials: true,
})
router := mux.NewRouter()
router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST")
router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET")
router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS")
router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS")
handler := c.Handler(router)
// Simple CORS middleware
corsMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
return
}
next.ServeHTTP(w, r)
})
}
log.Println("Server started on :8080")
log.Fatal(http.ListenAndServe(":8080", handler))
log.Fatal(http.ListenAndServe(":8080", corsMiddleware(router)))
}
func initDB() {
@ -108,105 +97,6 @@ func initDB() {
}
}
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
log.Println("Authorization header missing")
http.Error(w, "Authorization header required", http.StatusUnauthorized)
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
log.Println("Bearer token not found in Authorization header")
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
return
}
unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err == nil {
if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok {
log.Printf("Token claims (unverified): %+v\n", claims)
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
if time.Now().After(expTime) {
log.Printf("Token expired at: %v\n", expTime)
http.Error(w, "Token expired", http.StatusUnauthorized)
return
}
}
}
} else {
log.Println("Unverified parse error:", err)
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil {
log.Println("Token verification failed:", err)
http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
return
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
userID, ok := claims["user_id"].(string)
if !ok || userID == "" {
log.Println("Invalid user_id in token claims")
http.Error(w, "Invalid user ID in token", http.StatusUnauthorized)
return
}
var allowedDomains []string
if domains, ok := claims["allowed_urls"].([]interface{}); ok {
for _, d := range domains {
if domain, ok := d.(string); ok {
allowedDomains = append(allowedDomains, domain)
}
}
}
var reqDomain string
if r.Method == "POST" {
var chatReq ChatRequest
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if err := json.Unmarshal(bodyBytes, &chatReq); err == nil {
reqDomain = chatReq.Domain
}
} else {
reqDomain = r.URL.Query().Get("domain")
}
domainAllowed := false
for _, domain := range allowedDomains {
if domain == reqDomain {
domainAllowed = true
break
}
}
if !domainAllowed {
log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains)
http.Error(w, "Domain not allowed", http.StatusForbidden)
return
}
ctx := context.WithValue(r.Context(), "userID", userID)
next.ServeHTTP(w, r.WithContext(ctx))
} else {
log.Println("Invalid token claims")
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
}
}
}
func chatHandler(w http.ResponseWriter, r *http.Request) {
var req ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -214,10 +104,9 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
return
}
userID, ok := r.Context().Value("userID").(string)
if !ok {
http.Error(w, "User ID not found", http.StatusInternalServerError)
return
// Generate a simple user ID if not provided
if req.UserID == "" {
req.UserID = "guest_" + fmt.Sprintf("%d", time.Now().Unix())
}
response, err := getOllamaResponse(req.Message)
@ -229,7 +118,7 @@ func chatHandler(w http.ResponseWriter, r *http.Request) {
_, err = db.Exec(
"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
userID, req.Message, response, req.Domain,
req.UserID, req.Message, response, req.Domain,
)
if err != nil {
log.Println("Database error:", err)
@ -253,7 +142,7 @@ func getOllamaResponse(prompt string) (string, error) {
}
requestData := OllamaRequest{
Model: "gemma3:1b", // Using the model you have installed
Model: "gemma3:1b",
Prompt: prompt + "(under 20 words)",
Stream: false,
}
@ -287,22 +176,25 @@ func getOllamaResponse(prompt string) (string, error) {
}
func historyHandler(w http.ResponseWriter, r *http.Request) {
userID, ok := r.Context().Value("userID").(string)
if !ok {
http.Error(w, "User ID not found", http.StatusInternalServerError)
return
}
userID := r.URL.Query().Get("user_id")
domain := r.URL.Query().Get("domain")
if domain == "" {
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
return
}
rows, err := db.Query(
"SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50",
userID, domain,
)
query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?"
args := []interface{}{domain}
if userID != "" {
query += " AND user_id = ?"
args = append(args, userID)
}
query += " ORDER BY created_at DESC LIMIT 50"
rows, err := db.Query(query, args...)
if err != nil {
log.Println("Database error:", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)