first commit
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.env
|
||||
14
go.mod
Normal file
14
go.mod
Normal file
@@ -0,0 +1,14 @@
|
||||
module chatbot_api
|
||||
|
||||
go 1.21.0
|
||||
|
||||
toolchain go1.22.3
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/gorilla/mux v1.8.1 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/rs/cors v1.11.1 // indirect
|
||||
)
|
||||
14
go.sum
Normal file
14
go.sum
Normal file
@@ -0,0 +1,14 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/go-sql-driver/mysql v1.8.0 h1:UtktXaU2Nb64z/pLiGIxY4431SJ4/dR5cjMmlVHgnT4=
|
||||
github.com/go-sql-driver/mysql v1.8.0/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
|
||||
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
325
main.go
Normal file
325
main.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/rs/cors"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
var jwtSecret []byte
|
||||
|
||||
type Config struct {
|
||||
APIKeys map[string]APIKeyConfig `json:"api_keys"`
|
||||
}
|
||||
|
||||
type APIKeyConfig struct {
|
||||
Secret string `json:"secret"`
|
||||
AllowedURLs []string `json:"allowed_urls"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Message string `json:"message"`
|
||||
Response string `json:"response"`
|
||||
Domain string `json:"domain"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
log.Fatal("Error loading .env file")
|
||||
}
|
||||
|
||||
dbUser := os.Getenv("DB_USER")
|
||||
dbPass := os.Getenv("DB_PASS")
|
||||
dbHost := os.Getenv("DB_HOST")
|
||||
dbPort := os.Getenv("DB_PORT")
|
||||
dbName := os.Getenv("DB_NAME")
|
||||
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
|
||||
jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
||||
|
||||
origins := strings.Split(allowedOrigins, ",")
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbHost, dbPort, dbName)
|
||||
db, err = sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
initDB()
|
||||
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: origins,
|
||||
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||
AllowCredentials: true,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/chat", authMiddleware(chatHandler)).Methods("POST")
|
||||
router.HandleFunc("/history", authMiddleware(historyHandler)).Methods("GET")
|
||||
|
||||
handler := c.Handler(router)
|
||||
|
||||
log.Println("Server started on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", handler))
|
||||
}
|
||||
|
||||
func initDB() {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
response TEXT NOT NULL,
|
||||
domain VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
log.Println("Authorization header missing")
|
||||
http.Error(w, "Authorization header required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
log.Println("Bearer token not found in Authorization header")
|
||||
http.Error(w, "Invalid Authorization header format", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
unverifiedToken, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
|
||||
if err == nil {
|
||||
if claims, ok := unverifiedToken.Claims.(jwt.MapClaims); ok {
|
||||
log.Printf("Token claims (unverified): %+v\n", claims)
|
||||
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
if time.Now().After(expTime) {
|
||||
log.Printf("Token expired at: %v\n", expTime)
|
||||
http.Error(w, "Token expired", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Println("Unverified parse error:", err)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Println("Token verification failed:", err)
|
||||
http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
userID, ok := claims["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
log.Println("Invalid user_id in token claims")
|
||||
http.Error(w, "Invalid user ID in token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var allowedDomains []string
|
||||
if domains, ok := claims["allowed_urls"].([]interface{}); ok {
|
||||
for _, d := range domains {
|
||||
if domain, ok := d.(string); ok {
|
||||
allowedDomains = append(allowedDomains, domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var reqDomain string
|
||||
if r.Method == "POST" {
|
||||
var chatReq ChatRequest
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
if err := json.Unmarshal(bodyBytes, &chatReq); err == nil {
|
||||
reqDomain = chatReq.Domain
|
||||
}
|
||||
} else {
|
||||
reqDomain = r.URL.Query().Get("domain")
|
||||
}
|
||||
|
||||
domainAllowed := false
|
||||
for _, domain := range allowedDomains {
|
||||
if domain == reqDomain {
|
||||
domainAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !domainAllowed {
|
||||
log.Printf("Domain not allowed. Requested: %v, Allowed: %v\n", reqDomain, allowedDomains)
|
||||
http.Error(w, "Domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "userID", userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
} else {
|
||||
log.Println("Invalid token claims")
|
||||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chatHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var req ChatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := r.Context().Value("userID").(string)
|
||||
if !ok {
|
||||
http.Error(w, "User ID not found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := getOllamaResponse(req.Message)
|
||||
if err != nil {
|
||||
log.Println("Error getting response from Ollama:", err)
|
||||
http.Error(w, "Error processing message", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO chat_messages (user_id, message, response, domain) VALUES (?, ?, ?, ?)",
|
||||
userID, req.Message, response, req.Domain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Println("Database error:", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(ChatResponse{Message: response})
|
||||
}
|
||||
|
||||
func getOllamaResponse(prompt string) (string, error) {
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OllamaResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
requestData := OllamaRequest{
|
||||
Model: "gemma3:1b", // Using the model you have installed
|
||||
Prompt: prompt + "(under 20 words)",
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(
|
||||
"http://localhost:11434/api/generate",
|
||||
"application/json",
|
||||
bytes.NewBuffer(requestBody),
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error calling Ollama: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("Ollama API error: %s", string(body))
|
||||
}
|
||||
|
||||
var ollamaResp OllamaResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
||||
return "", fmt.Errorf("error decoding response: %v", err)
|
||||
}
|
||||
|
||||
return ollamaResp.Response, nil
|
||||
}
|
||||
|
||||
func historyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := r.Context().Value("userID").(string)
|
||||
if !ok {
|
||||
http.Error(w, "User ID not found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
domain := r.URL.Query().Get("domain")
|
||||
if domain == "" {
|
||||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
"SELECT id, user_id, message, response, domain, created_at FROM chat_messages WHERE user_id = ? AND domain = ? ORDER BY created_at DESC LIMIT 50",
|
||||
userID, domain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Println("Database error:", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []ChatMessage
|
||||
for rows.Next() {
|
||||
var msg ChatMessage
|
||||
if err := rows.Scan(&msg.ID, &msg.UserID, &msg.Message, &msg.Response, &msg.Domain, &msg.CreatedAt); err != nil {
|
||||
log.Println("Scan error:", err)
|
||||
continue
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(messages)
|
||||
}
|
||||
Reference in New Issue
Block a user