// lipi/main.go (218 lines, 5.2 KiB, Go)

package main
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/gorilla/mux"
"github.com/joho/godotenv"
)
// db is the package-wide database handle. main assigns it via sql.Open before
// the HTTP server starts; initDB and both handlers read it.
var db *sql.DB
// ChatRequest is the JSON body accepted by the /chat endpoint.
type ChatRequest struct {
	// Message is the user's text; it is forwarded to the model as the prompt.
	Message string `json:"message"`

	// Domain is stored alongside the exchange and is the key the /history
	// endpoint filters on.
	Domain string `json:"domain"`

	// UserID identifies the sender for basic tracking. It is optional: the
	// chat handler substitutes a generated guest ID when it is empty.
	UserID string `json:"user_id"`
}
// ChatResponse is the JSON body returned by the /chat endpoint.
type ChatResponse struct {
	// Message carries the text generated by the model.
	Message string `json:"message"`
}
// ChatMessage is one stored prompt/response exchange, as read back from the
// chat_messages table and serialized by the /history endpoint.
type ChatMessage struct {
	// ID is the row's auto-increment primary key.
	ID int `json:"id"`

	// UserID is the sender recorded when the exchange was saved.
	UserID string `json:"user_id"`

	// Message is the user's original text.
	Message string `json:"message"`

	// Response is the model's reply to Message.
	Response string `json:"response"`

	// Domain is the domain the exchange was recorded under.
	Domain string `json:"domain"`

	// CreatedAt is the row's creation timestamp.
	CreatedAt time.Time `json:"created_at"`
}
// main loads configuration from .env, opens the MySQL connection described by
// the DB_* environment variables, ensures the schema exists, and serves the
// chat API on :4012 behind a permissive CORS middleware.
func main() {
	if err := godotenv.Load(); err != nil {
		// Include the cause so a missing file and a malformed file can be
		// told apart from the log.
		log.Fatalf("Error loading .env file: %v", err)
	}

	dbUser := os.Getenv("DB_USER")
	dbPass := os.Getenv("DB_PASS")
	dbHost := os.Getenv("DB_HOST")
	dbPort := os.Getenv("DB_PORT")
	dbName := os.Getenv("DB_NAME")
	// parseTime=true makes the driver return DATETIME/TIMESTAMP columns as
	// time.Time, which the history handler scans into ChatMessage.CreatedAt.
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)

	var err error
	db, err = sql.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// sql.Open may only validate its arguments without connecting; Ping
	// surfaces bad credentials or an unreachable host here, with context,
	// instead of at the first query.
	if err := db.Ping(); err != nil {
		log.Fatalf("Error connecting to database: %v", err)
	}

	initDB()

	router := mux.NewRouter()
	router.HandleFunc("/chat", chatHandler).Methods("POST", "OPTIONS")
	router.HandleFunc("/history", historyHandler).Methods("GET", "OPTIONS")

	// Simple CORS middleware: allow any origin and answer preflight requests
	// directly without invoking the router.
	corsMiddleware := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Access-Control-Allow-Origin", "*")
			w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
			if r.Method == http.MethodOptions {
				return
			}
			next.ServeHTTP(w, r)
		})
	}

	// Use an explicit http.Server so header reads and idle keep-alive
	// connections are bounded. WriteTimeout is intentionally left unset
	// because /chat waits on model generation, whose duration is not bounded
	// here.
	srv := &http.Server{
		Addr:              ":4012",
		Handler:           corsMiddleware(router),
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	log.Println("Server started on :4012")
	log.Fatal(srv.ListenAndServe())
}
// initDB creates the chat_messages table if it does not already exist, using
// the package-level db handle. Any failure terminates the process.
func initDB() {
	const createChatMessages = `
CREATE TABLE IF NOT EXISTS chat_messages (
id INT AUTO_INCREMENT PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
message TEXT NOT NULL,
response TEXT NOT NULL,
domain VARCHAR(255) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
`
	if _, execErr := db.Exec(createChatMessages); execErr != nil {
		log.Fatal(execErr)
	}
}
// maxChatBodyBytes caps the size of a /chat request body (1 MiB) so a client
// cannot make the server buffer an arbitrarily large payload.
const maxChatBodyBytes = 1 << 20

// chatHandler serves POST /chat. It decodes a ChatRequest, asks the model for
// a reply to the message, stores the exchange in chat_messages, and writes the
// reply back as a ChatResponse.
//
// Responses:
//   - 400 if the body is not valid JSON, exceeds maxChatBodyBytes, or has an
//     empty message.
//   - 500 if the model call or the database insert fails (details are logged,
//     not returned).
//   - 200 with a JSON ChatResponse on success.
func chatHandler(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxChatBodyBytes)

	var req ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request: "+err.Error(), http.StatusBadRequest)
		return
	}

	// Reject empty prompts up front rather than spending a model call and a
	// database row on them.
	if req.Message == "" {
		http.Error(w, "Message is required", http.StatusBadRequest)
		return
	}

	// Generate a simple user ID if not provided. The generated ID is only
	// stored with the row; it is not returned to the client.
	if req.UserID == "" {
		req.UserID = fmt.Sprintf("guest_%d", time.Now().Unix())
	}

	response, err := getOllamaResponse(req.Message)
	if err != nil {
		log.Println("Error getting response from Ollama:", err)
		http.Error(w, "Error processing message", http.StatusInternalServerError)
		return
	}

	_, err = db.Exec(
		"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
		req.UserID, req.Message, response, req.Domain,
	)
	if err != nil {
		log.Println("Database error:", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(ChatResponse{Message: response}); err != nil {
		// The status line is already sent at this point, so logging is all
		// that can be done.
		log.Println("Error writing response:", err)
	}
}
// defaultOllamaGenerateURL is the generate endpoint used when the
// OLLAMA_GENERATE_URL environment variable is unset or empty.
const defaultOllamaGenerateURL = "http://localhost:11434/api/generate"

// ollamaClient is shared by all calls so connections are reused, and carries a
// timeout so a stalled Ollama instance cannot block a request forever.
var ollamaClient = &http.Client{Timeout: 2 * time.Minute}

// getOllamaResponse sends prompt (with a "(under 20 words)" suffix) to the
// Ollama generate API using the gemma3:1b model in non-streaming mode and
// returns the generated text.
//
// The endpoint defaults to defaultOllamaGenerateURL and can be overridden with
// the OLLAMA_GENERATE_URL environment variable. An error is returned if the
// request cannot be built or sent, if Ollama answers with a non-200 status
// (the status code and a bounded prefix of the body are included), or if the
// response body is not valid JSON. Underlying errors are wrapped with %w.
func getOllamaResponse(prompt string) (string, error) {
	type ollamaRequest struct {
		Model  string `json:"model"`
		Prompt string `json:"prompt"`
		Stream bool   `json:"stream"`
	}
	type ollamaResponse struct {
		Response string `json:"response"`
	}

	endpoint := os.Getenv("OLLAMA_GENERATE_URL")
	if endpoint == "" {
		endpoint = defaultOllamaGenerateURL
	}

	requestBody, err := json.Marshal(ollamaRequest{
		Model:  "gemma3:1b",
		Prompt: prompt + "(under 20 words)",
		Stream: false,
	})
	if err != nil {
		return "", fmt.Errorf("encoding Ollama request: %w", err)
	}

	resp, err := ollamaClient.Post(endpoint, "application/json", bytes.NewReader(requestBody))
	if err != nil {
		return "", fmt.Errorf("calling Ollama: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is only used for diagnostics, so read it best-effort and
		// cap it to keep a misbehaving upstream from flooding the log.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		return "", fmt.Errorf("unexpected Ollama status %d: %s", resp.StatusCode, body)
	}

	var decoded ollamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return "", fmt.Errorf("decoding Ollama response: %w", err)
	}
	return decoded.Response, nil
}
// historyHandler serves GET /history. It returns, as a JSON array, up to 50 of
// the most recent chat_messages rows for the required "domain" query
// parameter, newest first, optionally narrowed to a single "user_id".
//
// Responses:
//   - 400 if "domain" is missing.
//   - 500 if the query or row iteration fails (details are logged).
//   - 200 with a JSON array of ChatMessage; the array is empty ([]), never
//     null, when nothing matches.
func historyHandler(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	userID := params.Get("user_id")
	domain := params.Get("domain")
	if domain == "" {
		http.Error(w, "Domain parameter is required", http.StatusBadRequest)
		return
	}

	query := "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE domain = ?"
	args := []interface{}{domain}
	if userID != "" {
		query += " AND user_id = ?"
		args = append(args, userID)
	}
	query += " ORDER BY created_at DESC LIMIT 50"

	rows, err := db.Query(query, args...)
	if err != nil {
		log.Println("Database error:", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// Start from an empty, non-nil slice: a nil slice would be encoded as JSON
	// null, which clients expecting an array cannot iterate.
	messages := []ChatMessage{}
	for rows.Next() {
		var msg ChatMessage
		if err := rows.Scan(&msg.ID, &msg.UserID, &msg.Message, &msg.Response, &msg.Domain, &msg.CreatedAt); err != nil {
			// Best effort: skip the unreadable row but keep the rest.
			log.Println("Scan error:", err)
			continue
		}
		messages = append(messages, msg)
	}
	// rows.Next returns false both at end-of-data and on failure; only Err
	// distinguishes a complete result from a truncated one.
	if err := rows.Err(); err != nil {
		log.Println("Database error:", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(messages); err != nil {
		// Headers are already sent; the failure can only be logged.
		log.Println("Error writing response:", err)
	}
}