在使用大模型应用时,生成的内容往往是边生成边播放。如果是我们自己开发的应用,就可以使用双向流式的 TTS。我在使用火山引擎的双向流式语音合成时,发现官方没有提供 Python 版本的 demo,且官方文档的表述并不清晰,所以我在阅读 Go 语言版本的实现后,自己写了一个 Python 版本提供给大家。

官方文档 https://www.volcengine.com/docs/6561/1329505

代码

需要自行替换代码中的 YOUR_APP_KEY 和 YOUR_ACCESS_KEY,并按需修改 speaker(音色)参数

import json
import struct
import logging
from enum import IntEnum

"""
协议文档: https://www.volcengine.com/docs/6561/1329505
"""

logger = logging.getLogger(__name__)


class Event(IntEnum):
    """Protocol event codes for the bidirectional TTS websocket frames.

    Per the protocol document linked above: 1/2 open and close the
    connection and 50-52 are the server's connection-level replies;
    100/102 open and close a synthesis session and 150-153 are the
    server's session-level replies; 200 carries one chunk of text to
    synthesize; 350-352 report sentence boundaries and audio data.
    """

    NONE = 0
    # client -> server, connection level
    START_CONNECTION = 1
    FINISH_CONNECTION = 2
    # server -> client, connection level
    CONNECTION_STARTED = 50
    CONNECTION_FAILED = 51
    CONNECTION_FINISHED = 52
    # client -> server, session level
    START_SESSION = 100
    FINISH_SESSION = 102
    # server -> client, session level
    SESSION_STARTED = 150
    SESSION_FINISHED = 152
    SESSION_FAILED = 153
    # client -> server: one chunk of text to synthesize
    TASK_REQUEST = 200
    # server -> client: synthesis results
    TTS_SENTENCE_START = 350
    TTS_SENTENCE_END = 351
    TTS_RESPONSE = 352


def start_connection_frame():
    """Build the StartConnection frame with an empty JSON payload."""
    # Fixed 4-byte header: version 1 / header size 1 word, full client
    # request with the event-number flag set, JSON serialization with no
    # compression, reserved byte.
    header = bytes((0b0001_0001, 0b0001_0100, 0b0001_0000, 0b0000_0000))
    event = struct.pack(">i", Event.START_CONNECTION)
    payload = json.dumps({}).encode()
    size = struct.pack(">I", len(payload))
    return header + event + size + payload


def finish_connection_frame() -> bytes:
    """Build the FinishConnection frame with an empty JSON payload."""
    parts = (
        # version/header-size, full client request + event flag,
        # JSON serialization, reserved
        bytes((0b0001_0001, 0b0001_0100, 0b0001_0000, 0b0000_0000)),
        struct.pack(">i", Event.FINISH_CONNECTION),  # event type
        struct.pack(">I", 2),  # payload length of b"{}"
        b"{}",  # empty JSON payload
    )
    return b"".join(parts)


def start_session_frame(session_id: str, speaker: str, speech_rate: float = 2.0):
    """Build the StartSession frame carrying speaker and audio parameters.

    The JSON payload requests "mp3" output and repeats the event code 100.
    NOTE(review): the valid range of speech_rate comes from the Volcano
    protocol — confirm against the official docs.
    """
    meta = json.dumps(
        {
            "event": 100,
            "req_params": {
                "speaker": speaker,
                "audio_params": {
                    "format": "mp3",
                    "speech_rate": speech_rate,
                },
            },
        },
        ensure_ascii=False,
    ).encode()
    sid = session_id.encode()

    out = bytearray(bytes((0b0001_0001, 0b0001_0100, 0b0001_0000, 0b0000_0000)))
    out += struct.pack(">i", Event.START_SESSION)  # event type
    out += struct.pack(">I", len(sid))  # session id length
    out += sid
    out += struct.pack(">I", len(meta))  # payload length
    out += meta
    return bytes(out)


def finish_session_frame(session_id: str):
    """Build the FinishSession frame for *session_id* with an empty payload."""
    sid = session_id.encode()
    pieces = (
        # fixed header: version/size, msg-type/flags, serialization, reserved
        bytes((0b0001_0001, 0b0001_0100, 0b0001_0000, 0b0000_0000)),
        struct.pack(">i", Event.FINISH_SESSION),  # event type
        struct.pack(">I", len(sid)),  # session id length
        sid,
        struct.pack(">I", 2),  # payload length of b"{}"
        b"{}",  # empty JSON payload
    )
    return b"".join(pieces)


def send_task_frame(chunk: str, session_id: str):
    """Build a TaskRequest frame carrying one text chunk for *session_id*."""
    # Fix: use ensure_ascii=False, consistent with start_session_frame.
    # Without it, json.dumps escapes every CJK character as \uXXXX, roughly
    # sextupling the payload size for Chinese text; both forms are
    # equivalent JSON, so the server-visible content is unchanged.
    b_chunk_json = json.dumps(
        {
            "event": Event.TASK_REQUEST,  # IntEnum serializes as the int 200
            "req_params": {
                "text": chunk,
            },
        },
        ensure_ascii=False,
    ).encode()
    frame = bytearray()
    frame.append(0b0001_0001)  # protocol version / header size (1 word)
    frame.append(0b0001_0100)  # full client request, event-number flag set
    frame.append(0b0001_0000)  # JSON serialization, no compression
    frame.append(0b0000_0000)  # reserved
    frame.extend(struct.pack(">i", Event.TASK_REQUEST))  # event_type
    session_id_bytes = session_id.encode()
    frame.extend(struct.pack(">I", len(session_id_bytes)))  # session_id_len
    frame.extend(session_id_bytes)
    frame.extend(struct.pack(">I", len(b_chunk_json)))  # payload_len
    frame.extend(b_chunk_json)
    return bytes(frame)


def parse_frame(frame):
    """Decode one binary frame into header fields, session id and payload.

    Returns a dict with version, message_type, serialization_method,
    compression_method, event, session_id (None when the message type
    carries none) and the raw payload bytes.

    Raises ValueError if *frame* is not bytes.
    """
    if not isinstance(frame, bytes):
        raise ValueError(f"frame is not bytes: {frame}")

    version, header_words = frame[0] >> 4, frame[0] & 0x0F
    message_type, flags = frame[1] >> 4, frame[1] & 0x0F
    serialization_method, compression_method = frame[2] >> 4, frame[2] & 0x0F
    header_size = header_words * 4

    # NOTE(review): bytes 4:8 are always read as the event code, even when
    # the event-number flag (0x04) is absent — confirm against the protocol
    # document for flag-less frames.
    event = struct.unpack(">I", frame[4:8])[0]

    # Skip the optional 4-byte event number when the flag says it is present.
    offset = header_size + (4 if flags & 0x04 else 0)

    # Full client request (0b0001), full server response (0b1001) and
    # audio-only response (0b1011) frames carry a length-prefixed id next.
    session_id = None
    if message_type in (0b0001, 0b1001, 0b1011):
        (sid_len,) = struct.unpack(">I", frame[offset : offset + 4])
        session_id = frame[offset + 4 : offset + 4 + sid_len].decode()
        offset += 4 + sid_len

    (payload_len,) = struct.unpack(">I", frame[offset : offset + 4])
    payload = frame[offset + 4 : offset + 4 + payload_len]

    return {
        "version": version,
        "message_type": message_type,
        "serialization_method": serialization_method,
        "compression_method": compression_method,
        "event": event,
        "session_id": session_id,
        "payload": payload,
    }


import asyncio
import uuid
from collections.abc import AsyncGenerator

import websockets



class VolTtsClient:
    DEFAULT_API_ENDPOINT = "wss://openspeech.bytedance.com/api/v3/tts/bidirection"

    def get_headers(self):
        return {
            "X-Api-App-Key": "YOUR_APP_KEY",
            "X-Api-Access-Key": "YOUR_ACCESS_KEY",
            "X-Api-Resource-Id": "volc.service_type.10029",
            "X-Api-Request-Id": str(uuid.uuid1()),
        }

    async def send_task(
        self,
        session_id: str,
        text_generator: AsyncGenerator[str, None],
        ws,
    ):
        async for chunk in text_generator:
            task_frame = send_task_frame(
                chunk=chunk, session_id=session_id
            )
            await ws.send(task_frame)

        await ws.send(finish_session_frame(session_id))

    async def receive_response(self, ws):
        while True:
            response = await ws.recv()
            frame = parse_frame(response)
            match frame["event"]:
                case Event.TTS_RESPONSE:
                    yield frame["payload"]
                case (
                    Event.SESSION_FINISHED
                    | Event.FINISH_CONNECTION
                ):
                    break

    async def a_duplex_tts(
        self,
        message_id: str,
        text_generator: AsyncGenerator[str, None],
        speaker: str = "zh_female_shuangkuaisisi_moon_bigtts",
    ) -> AsyncGenerator[bytes, None]:
        async with websockets.connect(
            self.DEFAULT_API_ENDPOINT,
            additional_headers=self.get_headers(),
            # ping_interval=20,
        ) as ws:
            try:
                await ws.send(start_connection_frame())
                response = await ws.recv()
                logger.debug(parse_frame(response))

                start_session = start_session_frame(
                    session_id=message_id,
                    speaker=speaker,
                )
                await ws.send(start_session)
                response = await ws.recv()
                logger.debug(parse_frame(response))

                send_task = asyncio.create_task(
                    self.send_task(message_id, text_generator, ws)
                )

                async for audio_chunk in self.receive_response(ws):
                    yield audio_chunk

                # wait for send task to finish
                await send_task
                await ws.send(finish_session_frame(message_id))
                await ws.send(finish_connection_frame())

            except Exception as e:
                logger.error(e, exc_info=True)

测试代码如下:

from typing import AsyncGenerator

import pytest
from langchain_openai import ChatOpenAI

# NOTE(review): `import TtsClient` only imports the module; the class itself
# is needed for instantiation — confirm the module filename is TtsClient.py.
import TtsClient
from TtsClient import VolTtsClient


# Shared LLM used to produce the streamed text that feeds the TTS session.
# NOTE(review): requires an OpenAI-compatible API key in the environment.
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.1,
)


@pytest.mark.asyncio
async def test_run():
    """End-to-end smoke test: LLM stream -> duplex TTS -> audio file."""
    # Bug fix: the original called TtsClient() on the imported *module*,
    # which raises TypeError; instantiate the client class instead.
    client = VolTtsClient()

    async def a_text_generator() -> AsyncGenerator[str, None]:
        # Feed the LLM's streamed output chunk by chunk into the session.
        async for chunk in llm.astream("你好"):
            yield str(chunk.content)

    combined_audio = bytearray()
    async for chunk in client.a_duplex_tts(
        message_id="test_session_id", text_generator=a_text_generator()
    ):
        combined_audio.extend(chunk)

    # The session requests "mp3" audio, so save with a matching extension
    # (the original wrote raw MP3 bytes into a .wav file).
    with open("combined_audio.mp3", "wb") as audio_file:
        audio_file.write(combined_audio)

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐