diff --git a/main.go b/main.go index 40c7a50..95023c7 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "context" "database/sql" "encoding/json" "fmt" @@ -10,31 +9,19 @@ import ( "log" "net/http" "os" - "strings" "time" - "github.com/dgrijalva/jwt-go" _ "github.com/go-sql-driver/mysql" "github.com/gorilla/mux" "github.com/joho/godotenv" - "github.com/rs/cors" ) var db *sql.DB -var jwtSecret []byte - -type Config struct { - APIKeys map[string]APIKeyConfig `json:"api_keys"` -} - -type APIKeyConfig struct { - Secret string `json:"secret"` - AllowedURLs []string `json:"allowed_urls"` -} type ChatRequest struct { Message string `json:"message"` Domain string `json:"domain"` + UserID string `json:"user_id"` // Added user_id for basic tracking } type ChatResponse struct { @@ -61,10 +48,6 @@ func main() { dbHost := os.Getenv("DB_HOST") dbPort := os.Getenv("DB_PORT") dbName := os.Getenv("DB_NAME") - allowedOrigins := os.Getenv("ALLOWED_ORIGINS") - jwtSecret = []byte(os.Getenv("JWT_SECRET")) - - origins := strings.Split(allowedOrigins, ",") dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName) db, err = sql.Open("mysql", dsn) @@ -75,21 +58,27 @@ func main() { initDB() - c := cors.New(cors.Options{ - AllowedOrigins: origins, - AllowedMethods: []string{"GET", "POST", "OPTIONS"}, - AllowedHeaders: []string{"Content-Type", "Authorization"}, - AllowCredentials: true, - }) - router := mux.NewRouter() - router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST") - router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET") + router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS") + router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS") - handler := c.Handler(router) + // Simple CORS middleware + corsMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + if r.Method == "OPTIONS" { + return + } + + next.ServeHTTP(w, r) + }) + } log.Println("Server started on :8080") - log.Fatal(http.ListenAndServe(":8080", handler)) + log.Fatal(http.ListenAndServe(":8080", corsMiddleware(router))) } func initDB() { @@ -108,105 +97,6 @@ func initDB() { } } -func authMiddleware(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - log.Println("Authorization header missing") - http.Error(w, "Authorization header required", http.StatusUnauthorized) - return - } - - tokenString := strings.TrimPrefix(authHeader, "Bearer ") - if tokenString == authHeader { - log.Println("Bearer token not found in Authorization header") - http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized) - return - } - - unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) - if err == nil { - if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok { - log.Printf("Token claims (unverified): %+v\n", claims) - - if exp, ok := claims["exp"].(float64); ok { - expTime := time.Unix(int64(exp), 0) - if time.Now().After(expTime) { - log.Printf("Token expired at: %v\n", expTime) - http.Error(w, "Token expired", http.StatusUnauthorized) - return - } - } - } - } else { - log.Println("Unverified parse error:", err) - } - - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return jwtSecret, nil - }) - - if err != nil { - log.Println("Token verification failed:", err) - http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized) - return - } - - if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - userID, ok := claims["user_id"].(string) - if !ok || userID == "" { - log.Println("Invalid user_id in token claims") - http.Error(w, "Invalid user ID in token", http.StatusUnauthorized) - return - } - - var allowedDomains []string - if domains, ok := claims["allowed_urls"].([]interface{}); ok { - for _, d := range domains { - if domain, ok := d.(string); ok { - allowedDomains = append(allowedDomains, domain) - } - } - } - - var reqDomain string - if r.Method == "POST" { - var chatReq ChatRequest - bodyBytes, _ := io.ReadAll(r.Body) - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - if err := json.Unmarshal(bodyBytes, &chatReq); err == nil { - reqDomain = chatReq.Domain - } - } else { - reqDomain = r.URL.Query().Get("domain") - } - - domainAllowed := false - for _, domain := range allowedDomains { - if domain == reqDomain { - domainAllowed = true - break - } - } - - if !domainAllowed { - log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains) - http.Error(w, "Domain not allowed", http.StatusForbidden) - return - } - - ctx := context.WithValue(r.Context(), "userID", userID) - next.ServeHTTP(w, r.WithContext(ctx)) - } else { - log.Println("Invalid token claims") - http.Error(w, "Invalid token claims", http.StatusUnauthorized) - } - } -} - func chatHandler(w http.ResponseWriter, r *http.Request) { var req ChatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -214,10 +104,9 @@ func chatHandler(w http.ResponseWriter, r *http.Request) { return } - userID, ok := r.Context().Value("userID").(string) - if !ok { - http.Error(w, "User ID not found", http.StatusInternalServerError) - return + // Generate a simple user ID if not provided + if req.UserID == "" { + req.UserID = "guest_" + fmt.Sprintf("%d", time.Now().Unix()) } response, err := getOllamaResponse(req.Message) @@ -229,7 +118,7 @@ func chatHandler(w http.ResponseWriter, r *http.Request) { _, err = db.Exec( "INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)", - userID, req.Message, response, req.Domain, + req.UserID, req.Message, response, req.Domain, ) if err != nil { log.Println("Database error:", err) @@ -253,7 +142,7 @@ func getOllamaResponse(prompt string) (string, error) { } requestData := OllamaRequest{ - Model: "gemma3:1b", // Using the model you have installed + Model: "gemma3:1b", Prompt: prompt + "(under 20 words)", Stream: false, } @@ -287,22 +176,25 @@ func getOllamaResponse(prompt string) (string, error) { } func historyHandler(w http.ResponseWriter, r *http.Request) { - userID, ok := r.Context().Value("userID").(string) - if !ok { - http.Error(w, "User ID not found", http.StatusInternalServerError) - return - } - + userID := r.URL.Query().Get("user_id") domain := r.URL.Query().Get("domain") + if domain == "" { http.Error(w, "Domain parameter is required", http.StatusBadRequest) return } - rows, err := db.Query( - "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50", - userID, domain, - ) + query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?" + args := []interface{}{domain} + + if userID != "" { + query += " AND user_id = ?" + args = append(args, userID) + } + + query += " ORDER BY created_at DESC LIMIT 50" + + rows, err := db.Query(query, args...) if err != nil { log.Println("Database error:", err) http.Error(w, "Internal server error", http.StatusInternalServerError)