Allow objects with the __html__ protocol to be used for merge_fragments

This commit is contained in:
Chase Sterling 2025-04-13 23:07:10 -04:00
parent a4145207f0
commit 5a6155b8ed
2 changed files with 12 additions and 18 deletions

View File

@ -1,16 +1,2 @@
from typing import override
from fastcore.xml import to_xml
from .sse import SSE_HEADERS, ServerSentEventGenerator
from .starlette import DatastarStreamingResponse as _DatastarStreamingResponse
class DatastarStreamingResponse(_DatastarStreamingResponse):
@classmethod
@override
def merge_fragments(cls, fragments, *args, **kwargs):
if not isinstance(fragments, str):
fragments = to_xml(fragments)
# From here, business as usual
return super().merge_fragments(fragments, *args, **kwargs)
from .sse import ServerSentEventGenerator
from .starlette import DatastarStreamingResponse

View File

@ -1,6 +1,6 @@
import json
from itertools import chain
from typing import Optional
from typing import Optional, Protocol, Union, runtime_checkable
import datastar_py.consts as consts
@ -11,6 +11,12 @@ SSE_HEADERS = {
}
@runtime_checkable
class _HasHtml(Protocol):
def __html__(self) -> str:
...
class ServerSentEventGenerator:
__slots__ = ()
@ -38,13 +44,15 @@ class ServerSentEventGenerator:
@classmethod
def merge_fragments(
cls,
fragments: str,
fragments: Union[str, _HasHtml],
selector: Optional[str] = None,
merge_mode: Optional[consts.FragmentMergeMode] = None,
use_view_transition: bool = consts.DEFAULT_FRAGMENTS_USE_VIEW_TRANSITIONS,
event_id: Optional[int] = None,
retry_duration: int = consts.DEFAULT_SSE_RETRY_DURATION,
):
if isinstance(fragments, _HasHtml):
fragments = fragments.__html__()
data_lines = []
if merge_mode:
data_lines.append(f"data: {consts.MERGE_MODE_DATALINE_LITERAL} {merge_mode}")