|
|
|
@ -0,0 +1,325 @@
|
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"context"
|
|
|
|
|
"database/sql"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
"log"
|
|
|
|
|
"net/http"
|
|
|
|
|
"os"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/dgrijalva/jwt-go"
|
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
|
|
|
"github.com/gorilla/mux"
|
|
|
|
|
"github.com/joho/godotenv"
|
|
|
|
|
"github.com/rs/cors"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var db *sql.DB
|
|
|
|
|
var jwtSecret []byte
|
|
|
|
|
|
|
|
|
|
type Config struct {
|
|
|
|
|
APIKeys map[string]APIKeyConfig `json:"api_keys"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type APIKeyConfig struct {
|
|
|
|
|
Secret string `json:"secret"`
|
|
|
|
|
AllowedURLs []string `json:"allowed_urls"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type ChatRequest struct {
|
|
|
|
|
Message string `json:"message"`
|
|
|
|
|
Domain string `json:"domain"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type ChatResponse struct {
|
|
|
|
|
Message string `json:"message"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type ChatMessage struct {
|
|
|
|
|
ID int `json:"id"`
|
|
|
|
|
UserID string `json:"user_id"`
|
|
|
|
|
Message string `json:"message"`
|
|
|
|
|
Response string `json:"response"`
|
|
|
|
|
Domain string `json:"domain"`
|
|
|
|
|
CreatedAt time.Time `json:"created_at"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
|
err := godotenv.Load()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal("Error loading .env file")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dbUser := os.Getenv("DB_USER")
|
|
|
|
|
dbPass := os.Getenv("DB_PASS")
|
|
|
|
|
dbHost := os.Getenv("DB_HOST")
|
|
|
|
|
dbPort := os.Getenv("DB_PORT")
|
|
|
|
|
dbName := os.Getenv("DB_NAME")
|
|
|
|
|
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
|
|
|
|
|
jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
|
|
|
|
|
|
|
|
|
origins := strings.Split(allowedOrigins, ",")
|
|
|
|
|
|
|
|
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)
|
|
|
|
|
db, err = sql.Open("mysql", dsn)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
defer db.Close()
|
|
|
|
|
|
|
|
|
|
initDB()
|
|
|
|
|
|
|
|
|
|
c := cors.New(cors.Options{
|
|
|
|
|
AllowedOrigins: origins,
|
|
|
|
|
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
|
|
|
|
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
|
|
|
|
AllowCredentials: true,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
router := mux.NewRouter()
|
|
|
|
|
router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST")
|
|
|
|
|
router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET")
|
|
|
|
|
|
|
|
|
|
handler := c.Handler(router)
|
|
|
|
|
|
|
|
|
|
log.Println("Server started on :8080")
|
|
|
|
|
log.Fatal(http.ListenAndServe(":8080", handler))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func initDB() {
|
|
|
|
|
_, err := db.Exec(`
|
|
|
|
|
CREATE TABLE IF NOT EXISTS chat_messages (
|
|
|
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
|
|
|
user_id VARCHAR(255) NOT NULL,
|
|
|
|
|
message TEXT NOT NULL,
|
|
|
|
|
response TEXT NOT NULL,
|
|
|
|
|
domain VARCHAR(255) NOT NULL,
|
|
|
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
|
|
|
);
|
|
|
|
|
`)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
authHeader := r.Header.Get("Authorization")
|
|
|
|
|
if authHeader == "" {
|
|
|
|
|
log.Println("Authorization header missing")
|
|
|
|
|
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
|
|
|
if tokenString == authHeader {
|
|
|
|
|
log.Println("Bearer token not found in Authorization header")
|
|
|
|
|
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
|
|
|
|
|
if err == nil {
|
|
|
|
|
if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok {
|
|
|
|
|
log.Printf("Token claims (unverified): %+v\n", claims)
|
|
|
|
|
|
|
|
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
|
|
|
expTime := time.Unix(int64(exp), 0)
|
|
|
|
|
if time.Now().After(expTime) {
|
|
|
|
|
log.Printf("Token expired at: %v\n", expTime)
|
|
|
|
|
http.Error(w, "Token expired", http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
log.Println("Unverified parse error:", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
|
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
|
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
|
|
|
}
|
|
|
|
|
return jwtSecret, nil
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Println("Token verification failed:", err)
|
|
|
|
|
http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
|
|
|
userID, ok := claims["user_id"].(string)
|
|
|
|
|
if !ok || userID == "" {
|
|
|
|
|
log.Println("Invalid user_id in token claims")
|
|
|
|
|
http.Error(w, "Invalid user ID in token", http.StatusUnauthorized)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var allowedDomains []string
|
|
|
|
|
if domains, ok := claims["allowed_urls"].([]interface{}); ok {
|
|
|
|
|
for _, d := range domains {
|
|
|
|
|
if domain, ok := d.(string); ok {
|
|
|
|
|
allowedDomains = append(allowedDomains, domain)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var reqDomain string
|
|
|
|
|
if r.Method == "POST" {
|
|
|
|
|
var chatReq ChatRequest
|
|
|
|
|
bodyBytes, _ := io.ReadAll(r.Body)
|
|
|
|
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
|
if err := json.Unmarshal(bodyBytes, &chatReq); err == nil {
|
|
|
|
|
reqDomain = chatReq.Domain
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
reqDomain = r.URL.Query().Get("domain")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
domainAllowed := false
|
|
|
|
|
for _, domain := range allowedDomains {
|
|
|
|
|
if domain == reqDomain {
|
|
|
|
|
domainAllowed = true
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !domainAllowed {
|
|
|
|
|
log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains)
|
|
|
|
|
http.Error(w, "Domain not allowed", http.StatusForbidden)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx := context.WithValue(r.Context(), "userID", userID)
|
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
|
} else {
|
|
|
|
|
log.Println("Invalid token claims")
|
|
|
|
|
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func chatHandler(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
var req ChatRequest
|
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
|
|
|
http.Error(w, "Invalid request: "+err.Error(), http.StatusBadRequest)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
userID, ok := r.Context().Value("userID").(string)
|
|
|
|
|
if !ok {
|
|
|
|
|
http.Error(w, "User ID not found", http.StatusInternalServerError)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
response, err := getOllamaResponse(req.Message)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Println("Error getting response from Ollama:", err)
|
|
|
|
|
http.Error(w, "Error processing message", http.StatusInternalServerError)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = db.Exec(
|
|
|
|
|
"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
|
|
|
|
|
userID, req.Message, response, req.Domain,
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Println("Database error:", err)
|
|
|
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
|
json.NewEncoder(w).Encode(ChatResponse{Message: response})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func getOllamaResponse(prompt string) (string, error) {
|
|
|
|
|
type OllamaRequest struct {
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
Prompt string `json:"prompt"`
|
|
|
|
|
Stream bool `json:"stream"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type OllamaResponse struct {
|
|
|
|
|
Response string `json:"response"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
requestData := OllamaRequest{
|
|
|
|
|
Model: "gemma3:1b", // Using the model you have installed
|
|
|
|
|
Prompt: prompt + "(under 20 words)",
|
|
|
|
|
Stream: false,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
requestBody, err := json.Marshal(requestData)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", fmt.Errorf("error creating request: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
resp, err := http.Post(
|
|
|
|
|
"http://localhost:11434/api/generate",
|
|
|
|
|
"application/json",
|
|
|
|
|
bytes.NewBuffer(requestBody),
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", fmt.Errorf("error calling Ollama: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
|
return "", fmt.Errorf("Ollama API error: %s", string(body))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var ollamaResp OllamaResponse
|
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
|
|
|
|
return "", fmt.Errorf("error decoding response: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ollamaResp.Response, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func historyHandler(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
userID, ok := r.Context().Value("userID").(string)
|
|
|
|
|
if !ok {
|
|
|
|
|
http.Error(w, "User ID not found", http.StatusInternalServerError)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
domain := r.URL.Query().Get("domain")
|
|
|
|
|
if domain == "" {
|
|
|
|
|
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rows, err := db.Query(
|
|
|
|
|
"SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50",
|
|
|
|
|
userID, domain,
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Println("Database error:", err)
|
|
|
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
|
|
|
|
var messages []ChatMessage
|
|
|
|
|
for rows.Next() {
|
|
|
|
|
var msg ChatMessage
|
|
|
|
|
if err := rows.Scan(&msg.ID, &msg.UserID, &msg.Message, &msg.Response, &msg.Domain, &msg.CreatedAt); err != nil {
|
|
|
|
|
log.Println("Scan error:", err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
messages = append(messages, msg)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
|
json.NewEncoder(w).Encode(messages)
|
|
|
|
|
}
|