#!/usr/bin/env python3
"""
Improved Real-time Speech-to-Text WebSocket Server using Vosk
with better audio format handling.
"""
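
# Wire protocol implemented by this server: binary WebSocket frames carry audio
# chunks; text frames carry JSON commands ({"command": "ping" | "reset" |
# "get_info"}). Replies are JSON messages whose "type" is one of "status",
# "transcription", "error", "pong" or "server_info".
#
# Minimal client sketch (illustrative only; assumes the server is reachable at
# ws://localhost:8765, the defaults used by main() below):
#
#     import asyncio, json, websockets
#
#     async def demo():
#         async with websockets.connect("ws://localhost:8765") as ws:
#             print(await ws.recv())                          # "status" greeting
#             await ws.send(json.dumps({"command": "ping"}))
#             print(await ws.recv())                          # "pong"
#             # Stream audio by sending binary frames: WebM/Opus audio when
#             # FFmpeg is available server-side, or raw 16-bit mono PCM at the
#             # configured sample rate when it is not (see convert_to_pcm).
#
#     asyncio.run(demo())
#
# Command-line usage (the filename is whatever this file is saved as):
#     python3 <this_file>.py --model vosk-model-small-en-us-0.15 --port 8765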

import asyncio
import json
import logging
import os
import subprocess
import sys
import tempfile

import websockets
from vosk import Model, KaldiRecognizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ImprovedVoskSTTServer:
    def __init__(self, model_path="vosk-model-small-en-us-0.15", sample_rate=16000):
        """
        Initialize the improved Vosk STT server with FFmpeg support.

        Args:
            model_path: Path to the Vosk model directory
            sample_rate: Audio sample rate in Hz (16000 is recommended)
        """
        self.model_path = model_path
        self.sample_rate = sample_rate
        self.model = None
        self.check_dependencies()
        self.load_model()

    def check_dependencies(self):
        """Check whether FFmpeg is available."""
        try:
            subprocess.run(['ffmpeg', '-version'],
                           capture_output=True, check=True)
            logger.info("FFmpeg is available")
            self.has_ffmpeg = True
        except (subprocess.CalledProcessError, FileNotFoundError):
            logger.warning("FFmpeg not found. Audio conversion may be limited.")
            self.has_ffmpeg = False
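
    # If FFmpeg is missing, it can usually be installed with the system package
    # manager, e.g. `sudo apt install ffmpeg` (Debian/Ubuntu) or
    # `brew install ffmpeg` (macOS).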

    def load_model(self):
        """Load the Vosk model."""
        try:
            if not os.path.exists(self.model_path):
                logger.error(f"Model path {self.model_path} does not exist!")
                self.print_model_setup_instructions()
                raise FileNotFoundError(f"Model not found at {self.model_path}")

            logger.info(f"Loading Vosk model from {self.model_path}...")
            self.model = Model(self.model_path)
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def print_model_setup_instructions(self):
        """Print instructions for setting up a Vosk model."""
        logger.info("=" * 60)
        logger.info("VOSK MODEL SETUP INSTRUCTIONS")
        logger.info("=" * 60)
        logger.info("1. Download a Vosk model (choose based on your needs):")
        logger.info("")
        logger.info("   Small English model (~50MB):")
        logger.info("   wget https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip")
        logger.info("   unzip vosk-model-small-en-us-0.15.zip")
        logger.info("")
        logger.info("   Large English model (~1.8GB, better accuracy):")
        logger.info("   wget https://alphacephei.com/vosk/models/vosk-model-en-us-0.22.zip")
        logger.info("   unzip vosk-model-en-us-0.22.zip")
        logger.info("")
        logger.info("   Other languages available at: https://alphacephei.com/vosk/models")
        logger.info("")
        logger.info("2. Place the extracted model directory in the server folder")
        logger.info("3. Update the model path when starting the server")
        logger.info("=" * 60)

    async def handle_client(self, websocket, path=None):
        """Handle a WebSocket client connection."""
        # `path` is accepted but optional: older releases of the websockets
        # library pass it to the handler, newer ones do not.
        client_ip = websocket.remote_address[0]
        logger.info(f"New client connected: {client_ip}")

        # Create a recognizer for this client
        recognizer = KaldiRecognizer(self.model, self.sample_rate)

        try:
            await websocket.send(json.dumps({
                "type": "status",
                "message": "Connected to Vosk STT Server",
                "server_info": {
                    "sample_rate": self.sample_rate,
                    "has_ffmpeg": self.has_ffmpeg,
                    "model_path": self.model_path
                }
            }))

            async for message in websocket:
                try:
                    # Handle binary audio data
                    if isinstance(message, bytes):
                        await self.process_audio_chunk(websocket, recognizer, message)

                    # Handle text messages (commands, etc.)
                    elif isinstance(message, str):
                        await self.handle_text_message(websocket, recognizer, message)

                except Exception as e:
                    logger.error(f"Error processing message: {e}")
                    await websocket.send(json.dumps({
                        "type": "error",
                        "message": str(e)
                    }))

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_ip}")
        except Exception as e:
            logger.error(f"Error handling client {client_ip}: {e}")

    async def process_audio_chunk(self, websocket, recognizer, audio_data):
        """Process an incoming audio chunk with Vosk."""
        try:
            # Convert audio to PCM format for Vosk
            pcm_data = await self.convert_to_pcm(audio_data)

            if pcm_data:
                # Feed audio to the recognizer
                if recognizer.AcceptWaveform(pcm_data):
                    # Final result
                    result = json.loads(recognizer.Result())
                    if result.get('text', '').strip():
                        await websocket.send(json.dumps({
                            "type": "transcription",
                            "text": result['text'],
                            "final": True,
                            "confidence": result.get('confidence', 0.0),
                            "timestamp": asyncio.get_running_loop().time()
                        }))
                        logger.info(f"Final: {result['text']}")
                else:
                    # Partial result
                    partial_result = json.loads(recognizer.PartialResult())
                    if partial_result.get('partial', '').strip():
                        await websocket.send(json.dumps({
                            "type": "transcription",
                            "text": partial_result['partial'],
                            "final": False,
                            "confidence": 0.0,
                            "timestamp": asyncio.get_running_loop().time()
                        }))

        except Exception as e:
            logger.error(f"Error processing audio: {e}")

    async def convert_to_pcm(self, audio_data):
        """
        Convert various audio formats to 16-bit mono PCM using FFmpeg.
        """
        if not self.has_ffmpeg:
            # Fallback: assume the client already sends 16-bit mono PCM
            # at the configured sample rate.
            return audio_data

        try:
            # Write the incoming chunk to a temporary input file
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as input_file:
                input_file.write(audio_data)
                input_path = input_file.name

            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as output_file:
                output_path = output_file.name

            # Use FFmpeg to convert to PCM WAV format
            cmd = [
                'ffmpeg',
                '-i', input_path,
                '-acodec', 'pcm_s16le',        # 16-bit PCM
                '-ac', '1',                    # Mono
                '-ar', str(self.sample_rate),  # Sample rate
                '-f', 'wav',
                '-y',                          # Overwrite output
                output_path
            ]

            # Run the conversion asynchronously
            process = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )

            stdout, stderr = await process.communicate()

            if process.returncode == 0:
                # Read the converted audio
                with open(output_path, 'rb') as f:
                    wav_data = f.read()

                # Extract the PCM data (skip the canonical 44-byte WAV header
                # that FFmpeg writes for plain PCM output)
                pcm_data = wav_data[44:]

                # Cleanup
                os.unlink(input_path)
                os.unlink(output_path)

                return pcm_data
            else:
                logger.error(f"FFmpeg conversion failed: {stderr.decode()}")
                # Cleanup
                os.unlink(input_path)
                if os.path.exists(output_path):
                    os.unlink(output_path)
                return None

        except Exception as e:
            logger.error(f"Audio conversion error: {e}")
            return None
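
    # Note on convert_to_pcm(): decoding each chunk independently assumes every
    # chunk is a self-contained container. Browser MediaRecorder blobs after the
    # first one typically lack the WebM header, so such clients may need to send
    # complete recordings or fall back to raw 16-bit mono PCM (the non-FFmpeg
    # path above assumes exactly that format).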

    async def handle_text_message(self, websocket, recognizer, message):
        """Handle text-based (JSON) commands from the client."""
        try:
            data = json.loads(message)
            command = data.get('command')

            if command == 'ping':
                await websocket.send(json.dumps({
                    "type": "pong",
                    "timestamp": asyncio.get_running_loop().time()
                }))

            elif command == 'reset':
                # Reset the recognizer state
                recognizer.Reset()
                await websocket.send(json.dumps({
                    "type": "status",
                    "message": "Recognizer reset"
                }))

            elif command == 'get_info':
                await websocket.send(json.dumps({
                    "type": "server_info",
                    "sample_rate": self.sample_rate,
                    "has_ffmpeg": self.has_ffmpeg,
                    "model_path": self.model_path
                }))

        except json.JSONDecodeError:
            logger.error("Invalid JSON message received")

    async def start_server(self, host="0.0.0.0", port=8765):
        """Start the WebSocket server."""
        logger.info(f"Starting Vosk STT WebSocket server on {host}:{port}")
        logger.info(f"Using model: {self.model_path}")
        logger.info(f"Sample rate: {self.sample_rate}")
        logger.info(f"FFmpeg available: {self.has_ffmpeg}")

        try:
            async with websockets.serve(self.handle_client, host, port):
                logger.info("Server started successfully!")
                logger.info("Waiting for client connections...")
                logger.info("Press Ctrl+C to stop the server")

                # Keep the server running
                await asyncio.Future()  # run forever

        except Exception as e:
            logger.error(f"Server error: {e}")
            raise


def main():
    """Main entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Improved Vosk STT WebSocket Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8765, help='Port to bind to')
    parser.add_argument('--model', default='vosk-model-small-en-us-0.15',
                        help='Path to Vosk model directory')
    parser.add_argument('--sample-rate', type=int, default=16000,
                        help='Audio sample rate in Hz')

    args = parser.parse_args()

    try:
        # Create and start the server
        server = ImprovedVoskSTTServer(model_path=args.model, sample_rate=args.sample_rate)
        asyncio.run(server.start_server(host=args.host, port=args.port))

    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server failed to start: {e}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())