diff --git a/sdk/python/src/datastar_py/fasthtml.py b/sdk/python/src/datastar_py/fasthtml.py
index 2ba32078..6d7701da 100644
--- a/sdk/python/src/datastar_py/fasthtml.py
+++ b/sdk/python/src/datastar_py/fasthtml.py
@@ -1,16 +1,2 @@
-from typing import override
-
-from fastcore.xml import to_xml
-
-from .sse import SSE_HEADERS, ServerSentEventGenerator
-from .starlette import DatastarStreamingResponse as _DatastarStreamingResponse
-
-
-class DatastarStreamingResponse(_DatastarStreamingResponse):
- @classmethod
- @override
- def merge_fragments(cls, fragments, *args, **kwargs):
- if not isinstance(fragments, str):
- fragments = to_xml(fragments)
- # From here, business as usual
- return super().merge_fragments(fragments, *args, **kwargs)
+from .sse import ServerSentEventGenerator
+from .starlette import DatastarStreamingResponse
diff --git a/sdk/python/src/datastar_py/sse.py b/sdk/python/src/datastar_py/sse.py
index 73c7d20d..5ec9141d 100644
--- a/sdk/python/src/datastar_py/sse.py
+++ b/sdk/python/src/datastar_py/sse.py
@@ -1,6 +1,6 @@
import json
from itertools import chain
-from typing import Optional
+from typing import Optional, Protocol, Union, runtime_checkable
import datastar_py.consts as consts
@@ -11,6 +11,17 @@ SSE_HEADERS = {
}
+@runtime_checkable
+class _HtmlProvider(Protocol):
+ """A type that produces text ready to be placed in an HTML document.
+
+ This is a convention used by html producing/consuming libraries. This lets
+ e.g. fasthtml fasttags, or htpy elements, be passed straight in to
+ merge_fragments."""
+
+ def __html__(self) -> str: ...
+
+
class ServerSentEventGenerator:
__slots__ = ()
@@ -38,13 +49,15 @@ class ServerSentEventGenerator:
@classmethod
def merge_fragments(
cls,
- fragments: str,
+ fragments: Union[str, _HtmlProvider],
selector: Optional[str] = None,
merge_mode: Optional[consts.FragmentMergeMode] = None,
use_view_transition: bool = consts.DEFAULT_FRAGMENTS_USE_VIEW_TRANSITIONS,
event_id: Optional[int] = None,
retry_duration: int = consts.DEFAULT_SSE_RETRY_DURATION,
):
+ if isinstance(fragments, _HtmlProvider):
+ fragments = fragments.__html__()
data_lines = []
if merge_mode:
data_lines.append(f"data: {consts.MERGE_MODE_DATALINE_LITERAL} {merge_mode}")