From ede388e18429e2315b7909652801613f7270ee77 Mon Sep 17 00:00:00 2001 From: Suvodip Ghosh Date: Tue, 1 Jul 2025 17:17:50 +0530 Subject: [PATCH] first commit --- .gitignore | 1 + README.md | 0 go.mod | 14 +++ go.sum | 14 +++ main.go | 325 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 354 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2eea525 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.env \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2a73921 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module chatbot_api + +go 1.21.0 + +toolchain go1.22.3 + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect + github.com/go-sql-driver/mysql v1.9.3 // indirect + github.com/gorilla/mux v1.8.1 // indirect + github.com/joho/godotenv v1.5.1 // indirect + github.com/rs/cors v1.11.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3f83e6a --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/go-sql-driver/mysql v1.8.0 h1:UtktXaU2Nb64z/pLiGIxY4431SJ4/dR5cjMmlVHgnT4= +github.com/go-sql-driver/mysql v1.8.0/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= diff --git a/main.go b/main.go new file mode 100644 index 0000000..40c7a50 --- /dev/null +++ b/main.go @@ -0,0 +1,325 @@ +package main + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "github.com/dgrijalva/jwt-go" + _ "github.com/go-sql-driver/mysql" + "github.com/gorilla/mux" + "github.com/joho/godotenv" + "github.com/rs/cors" +) + +var db *sql.DB +var jwtSecret []byte + +type Config struct { + APIKeys map[string]APIKeyConfig `json:"api_keys"` +} + +type APIKeyConfig struct { + Secret string `json:"secret"` + AllowedURLs []string `json:"allowed_urls"` +} + +type ChatRequest struct { + Message string `json:"message"` + Domain string `json:"domain"` +} + +type ChatResponse struct { + Message string `json:"message"` +} + +type ChatMessage struct { + ID int `json:"id"` + UserID string `json:"user_id"` + Message string `json:"message"` + Response string `json:"response"` + Domain string `json:"domain"` + CreatedAt time.Time `json:"created_at"` +} + +func main() { + err := godotenv.Load() + if err != nil { + log.Fatal("Error loading .env file") + } + + dbUser := os.Getenv("DB_USER") + dbPass := os.Getenv("DB_PASS") + dbHost := os.Getenv("DB_HOST") + dbPort := os.Getenv("DB_PORT") + dbName := os.Getenv("DB_NAME") + allowedOrigins := os.Getenv("ALLOWED_ORIGINS") + jwtSecret = []byte(os.Getenv("JWT_SECRET")) + + origins := strings.Split(allowedOrigins, ",") + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName) + db, err = sql.Open("mysql", dsn) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + initDB() + + c := cors.New(cors.Options{ + AllowedOrigins: origins, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + AllowCredentials: true, + }) + + router := mux.NewRouter() + router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST") + router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET") + + handler := c.Handler(router) + + log.Println("Server started on :8080") + log.Fatal(http.ListenAndServe(":8080", handler)) +} + +func initDB() { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS chat_messages ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_id VARCHAR(255) NOT NULL, + message TEXT NOT NULL, + response TEXT NOT NULL, + domain VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + `) + if err != nil { + log.Fatal(err) + } +} + +func authMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + log.Println("Authorization header missing") + http.Error(w, "Authorization header required", http.StatusUnauthorized) + return + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + log.Println("Bearer token not found in Authorization header") + http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized) + return + } + + unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) + if err == nil { + if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok { + log.Printf("Token claims (unverified): %+v\n", claims) + + if exp, ok := claims["exp"].(float64); ok { + expTime := time.Unix(int64(exp), 0) + if time.Now().After(expTime) { + log.Printf("Token expired at: %v\n", expTime) + http.Error(w, "Token expired", http.StatusUnauthorized) + return + } + } + } + } else { + log.Println("Unverified parse error:", err) + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return jwtSecret, nil + }) + + if err != nil { + log.Println("Token verification failed:", err) + http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized) + return + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + userID, ok := claims["user_id"].(string) + if !ok || userID == "" { + log.Println("Invalid user_id in token claims") + http.Error(w, "Invalid user ID in token", http.StatusUnauthorized) + return + } + + var allowedDomains []string + if domains, ok := claims["allowed_urls"].([]interface{}); ok { + for _, d := range domains { + if domain, ok := d.(string); ok { + allowedDomains = append(allowedDomains, domain) + } + } + } + + var reqDomain string + if r.Method == "POST" { + var chatReq ChatRequest + bodyBytes, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + if err := json.Unmarshal(bodyBytes, &chatReq); err == nil { + reqDomain = chatReq.Domain + } + } else { + reqDomain = r.URL.Query().Get("domain") + } + + domainAllowed := false + for _, domain := range allowedDomains { + if domain == reqDomain { + domainAllowed = true + break + } + } + + if !domainAllowed { + log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains) + http.Error(w, "Domain not allowed", http.StatusForbidden) + return + } + + ctx := context.WithValue(r.Context(), "userID", userID) + next.ServeHTTP(w, r.WithContext(ctx)) + } else { + log.Println("Invalid token claims") + http.Error(w, "Invalid token claims", http.StatusUnauthorized) + } + } +} + +func chatHandler(w http.ResponseWriter, r *http.Request) { + var req ChatRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request: "+err.Error(), http.StatusBadRequest) + return + } + + userID, ok := r.Context().Value("userID").(string) + if !ok { + http.Error(w, "User ID not found", http.StatusInternalServerError) + return + } + + response, err := getOllamaResponse(req.Message) + if err != nil { + log.Println("Error getting response from Ollama:", err) + http.Error(w, "Error processing message", http.StatusInternalServerError) + return + } + + _, err = db.Exec( + "INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)", + userID, req.Message, response, req.Domain, + ) + if err != nil { + log.Println("Database error:", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ChatResponse{Message: response}) +} + +func getOllamaResponse(prompt string) (string, error) { + type OllamaRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` + } + + type OllamaResponse struct { + Response string `json:"response"` + } + + requestData := OllamaRequest{ + Model: "gemma3:1b", // Using the model you have installed + Prompt: prompt + "(under 20 words)", + Stream: false, + } + + requestBody, err := json.Marshal(requestData) + if err != nil { + return "", fmt.Errorf("error creating request: %v", err) + } + + resp, err := http.Post( + "http://localhost:11434/api/generate", + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + return "", fmt.Errorf("error calling Ollama: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("Ollama API error: %s", string(body)) + } + + var ollamaResp OllamaResponse + if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil { + return "", fmt.Errorf("error decoding response: %v", err) + } + + return ollamaResp.Response, nil +} + +func historyHandler(w http.ResponseWriter, r *http.Request) { + userID, ok := r.Context().Value("userID").(string) + if !ok { + http.Error(w, "User ID not found", http.StatusInternalServerError) + return + } + + domain := r.URL.Query().Get("domain") + if domain == "" { + http.Error(w, "Domain parameter is required", http.StatusBadRequest) + return + } + + rows, err := db.Query( + "SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50", + userID, domain, + ) + if err != nil { + log.Println("Database error:", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + defer rows.Close() + + var messages []ChatMessage + for rows.Next() { + var msg ChatMessage + if err := rows.Scan(&msg.ID, &msg.UserID, &msg.Message, &msg.Response, &msg.Domain, &msg.CreatedAt); err != nil { + log.Println("Scan error:", err) + continue + } + messages = append(messages, msg) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(messages) +}