package main import ( "bytes" "database/sql" "encoding/json" "fmt" "io" "log" "net/http" "os" "time" _ "github.com/go-sql-driver/mysql" "github.com/gorilla/mux" "github.com/joho/godotenv" ) var db *sql.DB type ChatRequest struct { Message string `json:"message"` Domain string `json:"domain"` UserID string `json:"user_id"` // Added user_id for basic tracking } type ChatResponse struct { Message string `json:"message"` } type ChatMessage struct { ID int `json:"id"` UserID string `json:"user_id"` Message string `json:"message"` Response string `json:"response"` Domain string `json:"domain"` CreatedAt time.Time `json:"created_at"` } func main() { err := godotenv.Load() if err != nil { log.Fatal("Error loading .env file") } dbUser := os.Getenv("DB_USER") dbPass := os.Getenv("DB_PASS") dbHost := os.Getenv("DB_HOST") dbPort := os.Getenv("DB_PORT") dbName := os.Getenv("DB_NAME") dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName) db, err = sql.Open("mysql", dsn) if err != nil { log.Fatal(err) } defer db.Close() initDB() router := mux.NewRouter() router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS") router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS") // Simple CORS middleware corsMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") if r.Method == "OPTIONS" { return } next.ServeHTTP(w, r) }) } log.Println("Server started on :4012") log.Fatal(http.ListenAndServe(":4012", corsMiddleware(router))) } func initDB() { _, err := db.Exec(` CREATE TABLE IF NOT EXISTS chat_messages ( id INT AUTO_INCREMENT PRIMARY KEY, user_id VARCHAR(255) NOT NULL, message TEXT NOT NULL, response TEXT NOT NULL, domain VARCHAR(255) NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); `) if err != nil { log.Fatal(err) } } func chatHandler(w http.ResponseWriter, r *http.Request) { var req ChatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request: "+err.Error(), http.StatusBadRequest) return } // Generate a simple user ID if not provided if req.UserID == "" { req.UserID = "guest_" + fmt.Sprintf("%d", time.Now().Unix()) } response, err := getOllamaResponse(req.Message) if err != nil { log.Println("Error getting response from Ollama:", err) http.Error(w, "Error processing message", http.StatusInternalServerError) return } _, err = db.Exec( "INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)", req.UserID, req.Message, response, req.Domain, ) if err != nil { log.Println("Database error:", err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(ChatResponse{Message: response}) } func getOllamaResponse(prompt string) (string, error) { type OllamaRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` Stream bool `json:"stream"` } type OllamaResponse struct { Response string `json:"response"` } requestData := OllamaRequest{ Model: "gemma3:1b", Prompt: prompt + "(under 20 words)", Stream: false, } requestBody, err := json.Marshal(requestData) if err != nil { return "", fmt.Errorf("error creating request: %v", err) } resp, err := http.Post( "http://localhost:11434/api/generate", "application/json", bytes.NewBuffer(requestBody), ) if err != nil { return "", fmt.Errorf("error calling Ollama: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("Ollama API error: %s", string(body)) } var ollamaResp OllamaResponse if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil { return "", fmt.Errorf("error decoding response: %v", err) } return ollamaResp.Response, nil } func historyHandler(w http.ResponseWriter, r *http.Request) { userID := r.URL.Query().Get("user_id") domain := r.URL.Query().Get("domain") if domain == "" { http.Error(w, "Domain parameter is required", http.StatusBadRequest) return } query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?" args := []interface{}{domain} if userID != "" { query += " AND user_id = ?" args = append(args, userID) } query += " ORDER BY created_at DESC LIMIT 50" rows, err := db.Query(query, args...) if err != nil { log.Println("Database error:", err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } defer rows.Close() var messages []ChatMessage for rows.Next() { var msg ChatMessage if err := rows.Scan(&msg.ID, &msg.UserID, &msg.Message, &msg.Response, &msg.Domain, &msg.CreatedAt); err != nil { log.Println("Scan error:", err) continue } messages = append(messages, msg) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(messages) }