#!/usr/bin/env python3
"""
Improved Real-time Speech-to-Text WebSocket Server using Vosk
with better audio format handling
"""

import argparse
import asyncio
import io
import json
import logging
import os
import subprocess
import tempfile
from typing import Optional

import websockets
from vosk import Model, KaldiRecognizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ImprovedVoskSTTServer:
    """WebSocket server that feeds client audio into a Vosk recognizer.

    Binary WebSocket messages are treated as audio (converted to 16-bit mono
    PCM through FFmpeg when FFmpeg is installed); text messages are treated
    as JSON commands (``ping``, ``reset``, ``get_info``).
    """

    def __init__(self, model_path="vosk-model-small-en-us-0.15", sample_rate=16000):
        """
        Initialize Improved Vosk STT Server with FFmpeg support

        Args:
            model_path: Path to Vosk model directory
            sample_rate: Audio sample rate (16000 is recommended)

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = model_path
        self.sample_rate = sample_rate
        self.model = None
        # Set for real by check_dependencies(); initialised here so the
        # attribute always exists on the instance.
        self.has_ffmpeg = False

        self.check_dependencies()
        self.load_model()

    def check_dependencies(self):
        """Check if FFmpeg is available and record the result in ``has_ffmpeg``."""
        try:
            subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
        except (subprocess.CalledProcessError, OSError):
            # OSError covers FileNotFoundError (not installed) as well as
            # PermissionError (present on PATH but not executable).
            logger.warning("FFmpeg not found. Audio conversion may be limited.")
            self.has_ffmpeg = False
        else:
            logger.info("FFmpeg is available")
            self.has_ffmpeg = True

    def load_model(self):
        """Load Vosk model from ``self.model_path`` into ``self.model``.

        Raises:
            FileNotFoundError: If the model directory does not exist.
            Exception: Any error raised while constructing the model is
                logged and re-raised.
        """
        try:
            if not os.path.exists(self.model_path):
                logger.error("Model path %s does not exist!", self.model_path)
                self.print_model_setup_instructions()
                raise FileNotFoundError(f"Model not found at {self.model_path}")

            logger.info("Loading Vosk model from %s...", self.model_path)
            self.model = Model(self.model_path)
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

    def print_model_setup_instructions(self):
        """Print instructions for setting up Vosk model"""
        separator = "=" * 60
        instructions = (
            separator,
            "VOSK MODEL SETUP INSTRUCTIONS",
            separator,
            "1. Download a Vosk model (choose based on your needs):",
            "",
            " Small English model (~50MB):",
            " wget https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip",
            " unzip vosk-model-small-en-us-0.15.zip",
            "",
            " Large English model (~1.8GB, better accuracy):",
            " wget https://alphacephei.com/vosk/models/vosk-model-en-us-0.22.zip",
            " unzip vosk-model-en-us-0.22.zip",
            "",
            " Other languages available at: https://alphacephei.com/vosk/models",
            "",
            "2. Place the extracted model directory in the server folder",
            "3. Update the model path when starting the server",
            separator,
        )
        for line in instructions:
            logger.info(line)

    @staticmethod
    def _now():
        """Return the running event loop's monotonic clock, used for timestamps."""
        return asyncio.get_running_loop().time()

    async def handle_client(self, websocket, path=None):
        """Handle WebSocket client connection

        Args:
            websocket: The client connection.
            path: Request path. Unused; optional so the handler works both
                with websockets releases that pass ``(websocket, path)`` and
                with those that pass only the connection.
        """
        remote = websocket.remote_address
        client_ip = remote[0] if remote else "unknown"
        logger.info("New client connected: %s", client_ip)

        # Create recognizer for this client
        recognizer = KaldiRecognizer(self.model, self.sample_rate)

        try:
            await websocket.send(json.dumps({
                "type": "status",
                "message": "Connected to Vosk STT Server",
                "server_info": {
                    "sample_rate": self.sample_rate,
                    "has_ffmpeg": self.has_ffmpeg,
                    "model_path": self.model_path
                }
            }))

            async for message in websocket:
                try:
                    # Handle binary audio data
                    if isinstance(message, bytes):
                        await self.process_audio_chunk(websocket, recognizer, message)
                    # Handle text messages (commands, etc.)
                    elif isinstance(message, str):
                        await self.handle_text_message(websocket, recognizer, message)
                except Exception as e:
                    # One bad message must not end the session: report it and
                    # keep reading. If the send itself fails because the peer
                    # is gone, the outer handlers take over.
                    logger.error("Error processing message: %s", e)
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": str(e)
                    }))

        except websockets.exceptions.ConnectionClosed:
            logger.info("Client disconnected: %s", client_ip)
        except Exception as e:
            logger.error("Error handling client %s: %s", client_ip, e)
        else:
            # The message loop ended without raising, i.e. a clean close.
            logger.info("Client disconnected: %s", client_ip)

    async def process_audio_chunk(self, websocket, recognizer, audio_data):
        """Process incoming audio chunk with Vosk

        Sends a ``transcription`` message with ``final: True`` when the
        recognizer reports a completed utterance, otherwise one with
        ``final: False`` carrying the partial hypothesis. Empty results are
        not sent. Errors are logged and not propagated.
        """
        try:
            # Convert audio to PCM format for Vosk
            pcm_data = await self.convert_to_pcm(audio_data)
            if not pcm_data:
                return

            # NOTE(review): AcceptWaveform is a synchronous call made from
            # this coroutine; consider an executor if it proves slow.
            if recognizer.AcceptWaveform(pcm_data):
                # Final result
                result = json.loads(recognizer.Result())
                if result.get('text', '').strip():
                    await websocket.send(json.dumps({
                        "type": "transcription",
                        "text": result['text'],
                        "final": True,
                        "confidence": result.get('confidence', 0.0),
                        "timestamp": self._now()
                    }))
                    logger.info("Final: %s", result['text'])
            else:
                # Partial result
                partial_result = json.loads(recognizer.PartialResult())
                if partial_result.get('partial', '').strip():
                    await websocket.send(json.dumps({
                        "type": "transcription",
                        "text": partial_result['partial'],
                        "final": False,
                        "confidence": 0.0,
                        "timestamp": self._now()
                    }))

        except Exception as e:
            logger.error("Error processing audio: %s", e)

    @staticmethod
    def _remove_quietly(file_path):
        """Delete ``file_path`` if set, logging rather than raising on failure."""
        if not file_path:
            return
        try:
            os.unlink(file_path)
        except OSError as e:
            logger.debug("Could not remove temporary file %s: %s", file_path, e)

    async def convert_to_pcm(self, audio_data) -> Optional[bytes]:
        """
        Convert various audio formats to PCM format using FFmpeg

        Args:
            audio_data: Encoded audio bytes received from the client.

        Returns:
            Raw 16-bit little-endian mono PCM at ``self.sample_rate``;
            ``audio_data`` unchanged when FFmpeg is unavailable; or ``None``
            when there is nothing to convert or the conversion fails.
        """
        if not self.has_ffmpeg:
            # Fallback: assume audio is already in compatible format
            return audio_data
        if not audio_data:
            # Nothing to decode; skip spawning FFmpeg.
            return None

        input_path = None
        try:
            # FFmpeg reads the chunk from a temporary file.
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as input_file:
                input_file.write(audio_data)
                input_path = input_file.name

            # Ask for headerless PCM on stdout instead of a WAV file: this
            # removes the output temp file and the need to guess how long
            # the WAV header is (it is not always 44 bytes).
            cmd = [
                'ffmpeg',
                '-hide_banner',
                '-loglevel', 'error',
                '-nostdin',
                '-i', input_path,
                '-acodec', 'pcm_s16le',        # 16-bit PCM
                '-ac', '1',                    # Mono
                '-ar', str(self.sample_rate),  # Sample rate
                '-f', 's16le',                 # Raw samples, no container
                'pipe:1',
            ]

            # Run conversion asynchronously
            process = await asyncio.create_subprocess_exec(
                *cmd,
                stdin=asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            stdout, stderr = await process.communicate()

            if process.returncode != 0:
                logger.error(
                    "FFmpeg conversion failed: %s",
                    stderr.decode(errors='replace')
                )
                return None
            return stdout

        except Exception as e:
            logger.error("Audio conversion error: %s", e)
            return None
        finally:
            # Always remove the input file, including on error paths.
            self._remove_quietly(input_path)

    async def handle_text_message(self, websocket, recognizer, message):
        """Handle text-based commands from client

        ``message`` must be a JSON object with a ``command`` key. Supported
        commands are ``ping``, ``reset`` and ``get_info``. Invalid JSON and
        unknown commands are logged; a payload that is valid JSON but not an
        object is answered with an ``error`` message.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            logger.error("Invalid JSON message received")
            return

        if not isinstance(data, dict):
            # e.g. a bare list or number: there is no "command" to look up.
            await websocket.send(json.dumps({
                "type": "error",
                "message": "Command message must be a JSON object"
            }))
            return

        command = data.get('command')

        if command == 'ping':
            await websocket.send(json.dumps({
                "type": "pong",
                "timestamp": self._now()
            }))
        elif command == 'reset':
            # Reset recognizer
            recognizer.Reset()
            await websocket.send(json.dumps({
                "type": "status",
                "message": "Recognizer reset"
            }))
        elif command == 'get_info':
            await websocket.send(json.dumps({
                "type": "server_info",
                "sample_rate": self.sample_rate,
                "has_ffmpeg": self.has_ffmpeg,
                "model_path": self.model_path
            }))
        else:
            # No reply is sent, as before; the log makes the drop visible.
            logger.warning("Ignoring unknown command: %r", command)

    async def start_server(self, host="0.0.0.0", port=5000):
        """Start the WebSocket server and serve until cancelled.

        Args:
            host: Interface to bind to.
            port: TCP port to listen on.
        """
        logger.info("Starting Vosk STT WebSocket server on %s:%s", host, port)
        logger.info("Using model: %s", self.model_path)
        logger.info("Sample rate: %s", self.sample_rate)
        logger.info("FFmpeg available: %s", self.has_ffmpeg)

        try:
            async with websockets.serve(self.handle_client, host, port):
                logger.info("Server started successfully!")
                logger.info("Waiting for client connections...")
                logger.info("Press Ctrl+C to stop the server")

                # Keep server running
                await asyncio.Future()  # run forever

        except Exception as e:
            logger.error("Server error: %s", e)
            raise


def main():
    """Main entry point

    Returns:
        Process exit status: 0 on normal shutdown (including Ctrl+C),
        1 if the server fails to start or crashes.
    """
    parser = argparse.ArgumentParser(description='Improved Vosk STT WebSocket Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8765, help='Port to bind to')
    parser.add_argument('--model', default='vosk-model-small-en-us-0.15',
                        help='Path to Vosk model directory')
    parser.add_argument('--sample-rate', type=int, default=16000,
                        help='Audio sample rate')

    args = parser.parse_args()

    try:
        # Create and start server
        server = ImprovedVoskSTTServer(model_path=args.model, sample_rate=args.sample_rate)
        asyncio.run(server.start_server(host=args.host, port=args.port))
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error("Server failed to start: %s", e)
        return 1

    return 0


if __name__ == "__main__":
    # SystemExit rather than the site-provided exit() helper, which is not
    # guaranteed to be defined in every interpreter configuration.
    raise SystemExit(main())