218 lines
5.2 KiB
Go
218 lines
5.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/gorilla/mux"
|
|
"github.com/joho/godotenv"
|
|
)
|
|
|
|
// db is the process-wide MySQL connection pool. main assigns it from
// sql.Open before the HTTP server starts; initDB and the handlers read it.
var (
	db *sql.DB
)
|
|
|
|
// ChatRequest is the JSON body accepted by POST /chat.
type ChatRequest struct {
	// Message is the user's prompt; chatHandler forwards it to the model.
	Message string `json:"message"`

	// Domain tags the conversation; GET /history filters on it.
	Domain string `json:"domain"`

	// UserID identifies the sender. chatHandler substitutes a generated
	// guest ID when the client leaves it empty.
	UserID string `json:"user_id"`
}
|
|
|
|
// ChatResponse is the JSON body returned by a successful POST /chat.
type ChatResponse struct {
	// Message holds the model's reply text.
	Message string `json:"message"`
}
|
|
|
|
// ChatMessage is one stored prompt/reply exchange: a row of the
// chat_messages table as returned by GET /history.
type ChatMessage struct {
	// ID is the row's AUTO_INCREMENT primary key.
	ID int `json:"id"`

	// UserID is the sender (possibly a generated guest ID).
	UserID string `json:"user_id"`

	// Message is the user's prompt.
	Message string `json:"message"`

	// Response is the model's reply to Message.
	Response string `json:"response"`

	// Domain is the tag the exchange was stored under.
	Domain string `json:"domain"`

	// CreatedAt is filled by the column's CURRENT_TIMESTAMP default; the
	// DSN's parseTime=true lets it scan directly into a time.Time.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
func main() {
|
|
err := godotenv.Load()
|
|
if err != nil {
|
|
log.Fatal("Error loading .env file")
|
|
}
|
|
|
|
dbUser := os.Getenv("DB_USER")
|
|
dbPass := os.Getenv("DB_PASS")
|
|
dbHost := os.Getenv("DB_HOST")
|
|
dbPort := os.Getenv("DB_PORT")
|
|
dbName := os.Getenv("DB_NAME")
|
|
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)
|
|
db, err = sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer db.Close()
|
|
|
|
initDB()
|
|
|
|
router := mux.NewRouter()
|
|
router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS")
|
|
router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS")
|
|
|
|
// Simple CORS middleware
|
|
corsMiddleware := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
|
|
|
if r.Method == "OPTIONS" {
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
log.Println("Server started on :8080")
|
|
log.Fatal(http.ListenAndServe(":8080", corsMiddleware(router)))
|
|
}
|
|
|
|
func initDB() {
|
|
_, err := db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS chat_messages (
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
user_id VARCHAR(255) NOT NULL,
|
|
message TEXT NOT NULL,
|
|
response TEXT NOT NULL,
|
|
domain VARCHAR(255) NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
);
|
|
`)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func chatHandler(w http.ResponseWriter, r *http.Request) {
|
|
var req ChatRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request: "+err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Generate a simple user ID if not provided
|
|
if req.UserID == "" {
|
|
req.UserID = "guest_" + fmt.Sprintf("%d", time.Now().Unix())
|
|
}
|
|
|
|
response, err := getOllamaResponse(req.Message)
|
|
if err != nil {
|
|
log.Println("Error getting response from Ollama:", err)
|
|
http.Error(w, "Error processing message", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
_, err = db.Exec(
|
|
"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
|
|
req.UserID, req.Message, response, req.Domain,
|
|
)
|
|
if err != nil {
|
|
log.Println("Database error:", err)
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(ChatResponse{Message: response})
|
|
}
|
|
|
|
const (
	// defaultOllamaURL is the generate endpoint of a local Ollama server.
	// It can be overridden with the OLLAMA_URL environment variable.
	defaultOllamaURL = "http://localhost:11434/api/generate"

	// defaultOllamaModel is used unless OLLAMA_MODEL is set.
	defaultOllamaModel = "gemma3:1b"

	// maxOllamaErrorBody bounds how much of a non-200 response body is
	// copied into the returned error.
	maxOllamaErrorBody = 4 << 10
)

// ollamaClient is shared across calls so connections to Ollama are reused.
// Its timeout bounds how long a /chat request can hang on a stuck model;
// the default http.Client has no timeout at all.
var ollamaClient = &http.Client{Timeout: 2 * time.Minute}

// envOrDefault returns the environment variable key, or fallback when it is
// unset or empty.
func envOrDefault(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// getOllamaResponse sends prompt, with a length hint appended, to Ollama's
// non-streaming generate API and returns the generated text.
//
// The endpoint and model come from OLLAMA_URL and OLLAMA_MODEL, falling back
// to defaultOllamaURL and defaultOllamaModel. Errors are returned for:
//   - transport failures or timeouts;
//   - non-200 statuses (the error includes the status and a bounded prefix
//     of the body);
//   - undecodable responses.
func getOllamaResponse(prompt string) (string, error) {
	type OllamaRequest struct {
		Model  string `json:"model"`
		Prompt string `json:"prompt"`
		Stream bool   `json:"stream"`
	}

	type OllamaResponse struct {
		Response string `json:"response"`
	}

	requestBody, err := json.Marshal(OllamaRequest{
		Model: envOrDefault("OLLAMA_MODEL", defaultOllamaModel),
		// Separate the hint from the user's text; without the space the
		// model sees e.g. "hello(under 20 words)".
		Prompt: prompt + " (under 20 words)",
		Stream: false,
	})
	if err != nil {
		return "", fmt.Errorf("encoding Ollama request: %w", err)
	}

	endpoint := envOrDefault("OLLAMA_URL", defaultOllamaURL)
	resp, err := ollamaClient.Post(endpoint, "application/json", bytes.NewReader(requestBody))
	if err != nil {
		return "", fmt.Errorf("calling Ollama: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort read of a bounded prefix, for diagnostics only. The
		// status code is the real signal, so a read error is ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxOllamaErrorBody))
		return "", fmt.Errorf("ollama API error: status %d: %s", resp.StatusCode, body)
	}

	var ollamaResp OllamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return "", fmt.Errorf("decoding Ollama response: %w", err)
	}

	return ollamaResp.Response, nil
}
|
|
|
|
func historyHandler(w http.ResponseWriter, r *http.Request) {
|
|
userID := r.URL.Query().Get("user_id")
|
|
domain := r.URL.Query().Get("domain")
|
|
|
|
if domain == "" {
|
|
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?"
|
|
args := []interface{}{domain}
|
|
|
|
if userID != "" {
|
|
query += " AND user_id = ?"
|
|
args = append(args, userID)
|
|
}
|
|
|
|
query += " ORDER BY created_at DESC LIMIT 50"
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
log.Println("Database error:", err)
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var messages []ChatMessage
|
|
for rows.Next() {
|
|
var msg ChatMessage
|
|
if err := rows.Scan(&msg.ID, &msg.UserID, &msg.Message, &msg.Response, &msg.Domain, &msg.CreatedAt); err != nil {
|
|
log.Println("Scan error:", err)
|
|
continue
|
|
}
|
|
messages = append(messages, msg)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(messages)
|
|
}
|