# stt-vosk-api/app2.py
#
# 306 lines
# 11 KiB
# Python
#
#!/usr/bin/env python3
"""
Improved Real-time Speech-to-Text WebSocket Server using Vosk
with better audio format handling
"""
import asyncio
import websockets
import json
import logging
import subprocess
import tempfile
import os
import io
from vosk import Model, KaldiRecognizer
# Configure the root logger: INFO level and above, each record timestamped.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the server class and main().
logger = logging.getLogger(__name__)
class ImprovedVoskSTTServer:
    """WebSocket server that streams client audio through a Vosk recognizer.

    One ``Model`` is loaded at construction time and shared; every connected
    client gets its own ``KaldiRecognizer``. Binary frames are treated as
    audio (converted to 16-bit mono PCM with FFmpeg when it is available)
    and text frames as JSON commands (``ping``, ``reset``, ``get_info``).
    """

    def __init__(self, model_path="vosk-model-small-en-us-0.15", sample_rate=16000):
        """
        Initialize Improved Vosk STT Server with FFmpeg support

        Args:
            model_path: Path to Vosk model directory
            sample_rate: Audio sample rate (16000 is recommended)

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        self.model_path = model_path
        self.sample_rate = sample_rate
        self.model = None
        self.check_dependencies()
        self.load_model()

    def check_dependencies(self):
        """Set ``self.has_ffmpeg`` according to whether ``ffmpeg -version`` runs."""
        try:
            subprocess.run(['ffmpeg', '-version'],
                           capture_output=True, check=True)
            logger.info("FFmpeg is available")
            self.has_ffmpeg = True
        except (subprocess.CalledProcessError, OSError):
            # OSError covers FileNotFoundError (not installed) as well as
            # PermissionError (present but not executable).
            logger.warning("FFmpeg not found. Audio conversion may be limited.")
            self.has_ffmpeg = False

    def load_model(self):
        """Load the Vosk model from ``self.model_path`` into ``self.model``.

        Raises:
            FileNotFoundError: if the model directory does not exist (setup
                instructions are logged first).
            Exception: anything raised by ``vosk.Model`` is logged and re-raised.
        """
        try:
            if not os.path.exists(self.model_path):
                logger.error(f"Model path {self.model_path} does not exist!")
                self.print_model_setup_instructions()
                raise FileNotFoundError(f"Model not found at {self.model_path}")
            logger.info(f"Loading Vosk model from {self.model_path}...")
            self.model = Model(self.model_path)
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def print_model_setup_instructions(self):
        """Log instructions for downloading and installing a Vosk model."""
        logger.info("=" * 60)
        logger.info("VOSK MODEL SETUP INSTRUCTIONS")
        logger.info("=" * 60)
        logger.info("1. Download a Vosk model (choose based on your needs):")
        logger.info("")
        logger.info(" Small English model (~50MB):")
        logger.info(" wget https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip")
        logger.info(" unzip vosk-model-small-en-us-0.15.zip")
        logger.info("")
        logger.info(" Large English model (~1.8GB, better accuracy):")
        logger.info(" wget https://alphacephei.com/vosk/models/vosk-model-en-us-0.22.zip")
        logger.info(" unzip vosk-model-en-us-0.22.zip")
        logger.info("")
        logger.info(" Other languages available at: https://alphacephei.com/vosk/models")
        logger.info("")
        logger.info("2. Place the extracted model directory in the server folder")
        logger.info("3. Update the model path when starting the server")
        logger.info("=" * 60)

    async def handle_client(self, websocket, path=None):
        """Serve one WebSocket client until it disconnects.

        Args:
            websocket: The client connection.
            path: Request path. Older ``websockets`` releases pass it as a
                second positional argument. Newer releases call the handler
                with the connection only, hence the ``None`` default. Unused.
        """
        remote = websocket.remote_address
        # remote_address may be empty/None (e.g. non-TCP transports).
        client_ip = remote[0] if remote else "unknown"
        logger.info(f"New client connected: {client_ip}")

        # A recognizer per client keeps decoding state separate between streams.
        recognizer = KaldiRecognizer(self.model, self.sample_rate)
        try:
            await websocket.send(json.dumps({
                "type": "status",
                "message": "Connected to Vosk STT Server",
                "server_info": {
                    "sample_rate": self.sample_rate,
                    "has_ffmpeg": self.has_ffmpeg,
                    "model_path": self.model_path
                }
            }))
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary frames carry audio.
                        await self.process_audio_chunk(websocket, recognizer, message)
                    elif isinstance(message, str):
                        # Text frames carry JSON commands.
                        await self.handle_text_message(websocket, recognizer, message)
                except Exception as e:
                    # Report per-message failures but keep the session alive.
                    logger.error(f"Error processing message: {e}")
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": str(e)
                    }))
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_ip}")
        except Exception as e:
            logger.error(f"Error handling client {client_ip}: {e}")

    async def _send_transcription(self, websocket, text, final, confidence):
        """Send one ``transcription`` message to the client."""
        await websocket.send(json.dumps({
            "type": "transcription",
            "text": text,
            "final": final,
            "confidence": confidence,
            "timestamp": asyncio.get_running_loop().time()
        }))

    async def process_audio_chunk(self, websocket, recognizer, audio_data):
        """Feed one audio chunk to ``recognizer`` and push any result to the client.

        When ``AcceptWaveform`` reports a completed utterance, the final text
        is sent with ``final=True``. Otherwise the current partial hypothesis
        is sent with ``final=False``. Empty or whitespace-only texts are not
        sent. Errors are logged and not raised, so one bad chunk does not end
        the session.
        """
        try:
            pcm_data = await self.convert_to_pcm(audio_data)
            if not pcm_data:
                return
            if recognizer.AcceptWaveform(pcm_data):
                result = json.loads(recognizer.Result())
                text = result.get('text', '')
                if text.strip():
                    # 0.0 is reported when the result has no 'confidence' key.
                    await self._send_transcription(
                        websocket, text, True, result.get('confidence', 0.0))
                    logger.info(f"Final: {text}")
            else:
                partial = json.loads(recognizer.PartialResult()).get('partial', '')
                if partial.strip():
                    await self._send_transcription(websocket, partial, False, 0.0)
        except Exception as e:
            logger.error(f"Error processing audio: {e}")

    async def convert_to_pcm(self, audio_data):
        """
        Convert an encoded audio chunk to raw PCM using FFmpeg.

        The output is headerless signed 16-bit little-endian mono PCM at
        ``self.sample_rate``, which is the data passed to
        ``KaldiRecognizer.AcceptWaveform``.

        NOTE(review): each chunk is converted on its own. If the client sends
        container fragments (e.g. sliced webm) where only the first slice
        carries the header, later chunks will fail to decode. Confirm the
        client sends self-contained chunks.

        Args:
            audio_data: Encoded audio bytes received from the client.

        Returns:
            PCM bytes; the input unchanged when FFmpeg is unavailable; or
            ``None`` when conversion fails.
        """
        if not self.has_ffmpeg:
            # Fallback: assume audio is already in compatible format
            return audio_data

        input_path = None
        try:
            # FFmpeg probes the actual container; the suffix is only a hint.
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as input_file:
                input_path = input_file.name
                input_file.write(audio_data)

            cmd = [
                'ffmpeg',
                '-i', input_path,
                '-acodec', 'pcm_s16le',  # 16-bit PCM
                '-ac', '1',  # Mono
                '-ar', str(self.sample_rate),  # Sample rate
                # Raw PCM to stdout: there is no WAV header to strip. A fixed
                # 44-byte skip is wrong whenever FFmpeg writes extra chunks
                # (e.g. LIST/INFO metadata) before the data chunk.
                '-f', 's16le',
                'pipe:1',
            ]
            process = await asyncio.create_subprocess_exec(
                *cmd,
                # Keep FFmpeg from reading the server's own stdin.
                stdin=asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            stdout, stderr = await process.communicate()
            if process.returncode != 0:
                logger.error(f"FFmpeg conversion failed: {stderr.decode(errors='replace')}")
                return None
            return stdout
        except Exception as e:
            logger.error(f"Audio conversion error: {e}")
            return None
        finally:
            # Always remove the temp input, including on errors or cancellation.
            if input_path is not None:
                try:
                    os.unlink(input_path)
                except OSError:
                    pass

    async def handle_text_message(self, websocket, recognizer, message):
        """Handle a JSON text command from the client.

        Supported ``command`` values:

        * ``ping``: reply with ``pong``.
        * ``reset``: call ``recognizer.Reset()``.
        * ``get_info``: reply with server settings.

        Messages that are not a JSON object are logged and ignored. Unknown
        commands are ignored.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            logger.error("Invalid JSON message received")
            return
        if not isinstance(data, dict):
            # e.g. a bare JSON list/number: there is no 'command' to read.
            logger.error("Invalid JSON message received")
            return

        command = data.get('command')
        if command == 'ping':
            await websocket.send(json.dumps({
                "type": "pong",
                "timestamp": asyncio.get_running_loop().time()
            }))
        elif command == 'reset':
            # Reset recognizer
            recognizer.Reset()
            await websocket.send(json.dumps({
                "type": "status",
                "message": "Recognizer reset"
            }))
        elif command == 'get_info':
            await websocket.send(json.dumps({
                "type": "server_info",
                "sample_rate": self.sample_rate,
                "has_ffmpeg": self.has_ffmpeg,
                "model_path": self.model_path
            }))

    async def start_server(self, host="0.0.0.0", port=5000):
        """Start the WebSocket server and run until cancelled.

        Note: the default ``port`` here (5000) differs from the CLI default in
        ``main`` (8765); ``main`` always passes the port explicitly.
        """
        logger.info(f"Starting Vosk STT WebSocket server on {host}:{port}")
        logger.info(f"Using model: {self.model_path}")
        logger.info(f"Sample rate: {self.sample_rate}")
        logger.info(f"FFmpeg available: {self.has_ffmpeg}")
        try:
            async with websockets.serve(self.handle_client, host, port):
                logger.info("Server started successfully!")
                logger.info("Waiting for client connections...")
                logger.info("Press Ctrl+C to stop the server")
                # Keep server running
                await asyncio.Future()  # run forever
        except Exception as e:
            logger.error(f"Server error: {e}")
            raise
def main():
    """Parse command-line options, then build and run the STT server.

    Returns:
        Exit status: 0 after a normal stop or Ctrl+C, 1 if constructing
        or running the server raised.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Improved Vosk STT WebSocket Server')
    option_specs = (
        ('--host', {'default': '0.0.0.0', 'help': 'Host to bind to'}),
        ('--port', {'type': int, 'default': 8765, 'help': 'Port to bind to'}),
        ('--model', {'default': 'vosk-model-small-en-us-0.15',
                     'help': 'Path to Vosk model directory'}),
        ('--sample-rate', {'type': int, 'default': 16000,
                           'help': 'Audio sample rate'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    try:
        stt_server = ImprovedVoskSTTServer(model_path=opts.model,
                                           sample_rate=opts.sample_rate)
        asyncio.run(stt_server.start_server(host=opts.host, port=opts.port))
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as err:
        logger.error(f"Server failed to start: {err}")
        return 1
    return 0
if __name__ == "__main__":
    # sys.exit is always available; the bare `exit` builtin is injected by
    # the `site` module and is missing under `python -S` or in frozen apps.
    import sys

    sys.exit(main())