63 changes: 49 additions & 14 deletions agentlightning/types/resources.py
@@ -4,14 +4,8 @@

import inspect
import logging
-from typing import (
-    Annotated,
-    Any,
-    Dict,
-    Literal,
-    Optional,
-    Union,
-)
+from pathlib import Path
+from typing import Annotated, Any, Dict, Literal, Optional, Union

from pydantic import BaseModel, Field

@@ -138,20 +132,61 @@ class PromptTemplate(Resource):
        engine (Literal['jinja', 'f-string', 'poml']): The templating engine
            to use for rendering the prompt. I imagine users can use their own
            customized engines, but algos can only well operate on a subset of them.
+
+    Notes:
+        ``poml`` templates accept an additional keyword argument ``_poml_format``
+        when calling :meth:`format`. This value is forwarded to :func:`poml.poml`
+        and defaults to ``"openai_chat"``.
    """

    resource_type: Literal["prompt_template"] = "prompt_template"
    template: str
    engine: Literal["jinja", "f-string", "poml"]

-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt template with the given kwargs."""
+    def format(self, **kwargs: Any) -> Any:
+        """Format the prompt template with the given kwargs.
+
+        Returns:
+            Any: The rendered prompt. ``f-string`` and ``jinja`` engines return a
+            string, while ``poml`` returns the object produced by :func:`poml.poml`.
+
+        Raises:
+            RuntimeError: If the required optional dependency for the configured
+            engine is not available.
+        """
        if self.engine == "f-string":
            return self.template.format(**kwargs)
-        else:
-            raise NotImplementedError(
-                "Formatting prompt templates for non-f-string engines with format() helper is not supported yet."
-            )
+        if self.engine == "jinja":
+            try:
+                from jinja2 import Template
+            except ImportError as exc:  # pragma: no cover - defensive guard
+                raise RuntimeError(
+                    "Formatting a PromptTemplate with engine 'jinja' requires the 'jinja2' package to be installed."
+                ) from exc
+
+            template = Template(self.template)
+            return template.render(**kwargs)
+        if self.engine == "poml":
+            try:
+                import poml  # type: ignore[import-not-found]
+            except ImportError as exc:  # pragma: no cover - defensive guard
+                raise RuntimeError(
+                    "Formatting a PromptTemplate with engine 'poml' requires the 'poml' package to be installed."
+                ) from exc
+
+            poml_kwargs = dict(kwargs)
+            poml_format = poml_kwargs.pop("_poml_format", "openai_chat")
+
+            template_input: Any
+            template_path = Path(self.template)
+            if template_path.suffix == ".poml" and template_path.exists():
+                template_input = template_path
+            else:
+                template_input = self.template
+
+            return poml.poml(template_input, context=poml_kwargs, format=poml_format)
+
+        raise NotImplementedError(f"Unknown prompt template engine: {self.engine}")


# Use discriminated union for proper deserialization
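For context on the comment above: pydantic dispatches deserialization of a discriminated union on each member's `resource_type` literal. The snippet below is a minimal, self-contained sketch of that pattern under pydantic v2; the `LLM` class, the `ResourceUnion` alias, and the `TypeAdapter` usage are illustrative assumptions, not the definition this repository actually uses (that code is not shown in this diff).

# Minimal sketch of a pydantic v2 discriminated union keyed on `resource_type`.
# `LLM` and `ResourceUnion` are assumed names for illustration only.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class LLM(BaseModel):
    resource_type: Literal["llm"] = "llm"
    endpoint: str


class PromptTemplate(BaseModel):
    resource_type: Literal["prompt_template"] = "prompt_template"
    template: str
    engine: Literal["jinja", "f-string", "poml"]


# The discriminator tells pydantic which concrete model to instantiate.
ResourceUnion = Annotated[Union[LLM, PromptTemplate], Field(discriminator="resource_type")]

resource = TypeAdapter(ResourceUnion).validate_python(
    {"resource_type": "prompt_template", "template": "Hello {name}!", "engine": "f-string"}
)
assert isinstance(resource, PromptTemplate)

With a discriminator in place, a serialized resource dict deserializes back into the right subclass without any custom dispatch logic.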
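Putting the new `format()` behavior together, a rough end-to-end usage sketch looks like the following. It assumes the optional `jinja2` and `poml` packages are installed (otherwise `format()` raises `RuntimeError`, as added above); the inline poml markup mirrors the fixture used in the new tests rather than a realistic prompt. The tests below exercise the same code paths with stubbed modules, so they run without either dependency.

# Usage sketch only; assumes the optional `jinja2` and `poml` packages are installed.
from agentlightning.types import PromptTemplate

# f-string engine: plain str.format over the template.
greeting = PromptTemplate(template="Hello {name}!", engine="f-string")
assert greeting.format(name="World") == "Hello World!"

# jinja engine: rendered via jinja2.Template(...).render(...).
jinja_greeting = PromptTemplate(template="Hello {{ name }}!", engine="jinja")
assert jinja_greeting.format(name="World") == "Hello World!"

# poml engine: kwargs become the poml context; `_poml_format` is popped and
# forwarded to poml.poml() (default "openai_chat"). A template that is a path
# to an existing *.poml file is passed through as a Path instead of a string.
chat_prompt = PromptTemplate(template="<poml>{{ name }}</poml>", engine="poml")
messages = chat_prompt.format(name="World", _poml_format="openai_chat")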
75 changes: 75 additions & 0 deletions tests/types/test_prompt_template.py
@@ -0,0 +1,75 @@
from __future__ import annotations

import sys
from pathlib import Path
from types import SimpleNamespace

import pytest

from agentlightning.types import PromptTemplate


def test_prompt_template_format_f_string() -> None:
    template = PromptTemplate(template="Hello {name}!", engine="f-string")

    assert template.format(name="World") == "Hello World!"


def test_prompt_template_format_jinja(monkeypatch: pytest.MonkeyPatch) -> None:
    class DummyTemplate:
        def __init__(self, source: str) -> None:
            self.source = source

        def render(self, **context: str) -> str:
            result = self.source
            for key, value in context.items():
                result = result.replace(f"{{{{ {key} }}}}", value)
            return result

    monkeypatch.setitem(sys.modules, "jinja2", SimpleNamespace(Template=DummyTemplate))

    template = PromptTemplate(template="Hello {{ name }}!", engine="jinja")

    assert template.format(name="World") == "Hello World!"


def test_prompt_template_format_poml_inline(monkeypatch: pytest.MonkeyPatch) -> None:
    calls: list[tuple[object, dict[str, object], str]] = []

    def dummy_poml(template: object, context: dict[str, object], format: str) -> dict[str, object]:
        calls.append((template, context, format))
        return {"template": template, "context": context, "format": format}

    monkeypatch.setitem(sys.modules, "poml", SimpleNamespace(poml=dummy_poml))

    template = PromptTemplate(template="<poml>{{ name }}</poml>", engine="poml")

    result = template.format(name="World")

    assert calls == [("<poml>{{ name }}</poml>", {"name": "World"}, "openai_chat")]
    assert result == {"template": "<poml>{{ name }}</poml>", "context": {"name": "World"}, "format": "openai_chat"}


def test_prompt_template_format_poml_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    calls: list[tuple[object, dict[str, object], str]] = []

    def dummy_poml(template: object, context: dict[str, object], format: str) -> dict[str, object]:
        calls.append((template, context, format))
        return {"template": template, "context": context, "format": format}

    monkeypatch.setitem(sys.modules, "poml", SimpleNamespace(poml=dummy_poml))

    poml_file = tmp_path / "sample.poml"
    poml_file.write_text("<poml>{{ name }}</poml>")

    template = PromptTemplate(template=str(poml_file), engine="poml")

    result = template.format(name="World", _poml_format="raw")

    assert len(calls) == 1
    call_template, context, output_format = calls[0]
    assert isinstance(call_template, Path)
    assert call_template == poml_file
    assert context == {"name": "World"}
    assert output_format == "raw"
    assert result == {"template": poml_file, "context": {"name": "World"}, "format": "raw"}