From 3c5974c2e15d74a2bd2f8326232dbe59d84d6135 Mon Sep 17 00:00:00 2001 From: Chase Sterling Date: Mon, 14 Apr 2025 21:31:20 -0400 Subject: [PATCH] Python SDK: Allow objects with the __html__ protocol to be used for merge_fragments (#837) * Allow objects with the __html__ protocol to be used for merge_fragments * Clarify what HasHtml is for with a docstring --- sdk/python/src/datastar_py/fasthtml.py | 18 ++---------------- sdk/python/src/datastar_py/sse.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/sdk/python/src/datastar_py/fasthtml.py b/sdk/python/src/datastar_py/fasthtml.py index 2ba32078..6d7701da 100644 --- a/sdk/python/src/datastar_py/fasthtml.py +++ b/sdk/python/src/datastar_py/fasthtml.py @@ -1,16 +1,2 @@ -from typing import override - -from fastcore.xml import to_xml - -from .sse import SSE_HEADERS, ServerSentEventGenerator -from .starlette import DatastarStreamingResponse as _DatastarStreamingResponse - - -class DatastarStreamingResponse(_DatastarStreamingResponse): - @classmethod - @override - def merge_fragments(cls, fragments, *args, **kwargs): - if not isinstance(fragments, str): - fragments = to_xml(fragments) - # From here, business as usual - return super().merge_fragments(fragments, *args, **kwargs) +from .sse import ServerSentEventGenerator +from .starlette import DatastarStreamingResponse diff --git a/sdk/python/src/datastar_py/sse.py b/sdk/python/src/datastar_py/sse.py index 73c7d20d..5ec9141d 100644 --- a/sdk/python/src/datastar_py/sse.py +++ b/sdk/python/src/datastar_py/sse.py @@ -1,6 +1,6 @@ import json from itertools import chain -from typing import Optional +from typing import Optional, Protocol, Union, runtime_checkable import datastar_py.consts as consts @@ -11,6 +11,17 @@ SSE_HEADERS = { } +@runtime_checkable +class _HtmlProvider(Protocol): + """A type that produces text ready to be placed in an HTML document. + + This is a convention used by html producing/consuming libraries. This lets + e.g. fasthtml fasttags, or htpy elements, be passed straight in to + merge_fragments.""" + + def __html__(self) -> str: ... + + class ServerSentEventGenerator: __slots__ = () @@ -38,13 +49,15 @@ class ServerSentEventGenerator: @classmethod def merge_fragments( cls, - fragments: str, + fragments: Union[str, _HtmlProvider], selector: Optional[str] = None, merge_mode: Optional[consts.FragmentMergeMode] = None, use_view_transition: bool = consts.DEFAULT_FRAGMENTS_USE_VIEW_TRANSITIONS, event_id: Optional[int] = None, retry_duration: int = consts.DEFAULT_SSE_RETRY_DURATION, ): + if isinstance(fragments, _HtmlProvider): + fragments = fragments.__html__() data_lines = [] if merge_mode: data_lines.append(f"data: {consts.MERGE_MODE_DATALINE_LITERAL} {merge_mode}")