Merge pull request #658 from pytest-dev/ab/fix-typing

Fix typing
This commit is contained in:
Alessio Bogon 2024-12-05 23:27:38 +01:00 committed by GitHub
commit 9bb4967d9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 327 additions and 154 deletions

View File

@ -20,9 +20,15 @@ Deprecated
Removed
+++++++
* The following private attributes are not available anymore (`#658 <https://github.com/pytest-dev/pytest-bdd/pull/658>`_):
* ``_pytest.reports.TestReport.scenario``; replaced by ``pytest_bdd.reporting.test_report_context`` WeakKeyDictionary (internal use)
* ``__scenario__`` attribute of test functions generated by the ``@scenario`` (and ``@scenarios``) decorator; replaced by ``pytest_bdd.scenario.scenario_wrapper_template_registry`` WeakKeyDictionary (internal use)
* ``_pytest.nodes.Item.__scenario_report__``; replaced by ``pytest_bdd.reporting.scenario_reports_registry`` WeakKeyDictionary (internal use)
* ``_pytest_bdd_step_context`` attribute of internal test function markers; replaced by ``pytest_bdd.steps.step_function_context_registry`` WeakKeyDictionary (internal use)
Fixed
+++++
* Made type annotations stronger and removed most of the ``typing.Any`` usages and ``# type: ignore`` annotations. `#658 <https://github.com/pytest-dev/pytest-bdd/pull/658>`_
Security
++++++++
@ -137,7 +143,7 @@ Fixed
7.0.1
-----
- Fix errors occurring if `pytest_unconfigure` is called before `pytest_configure`. `#362 <https://github.com/pytest-dev/pytest-bdd/issues/362>`_ `#641 <https://github.com/pytest-dev/pytest-bdd/pull/641>`_
- Fix errors occurring if ``pytest_unconfigure`` is called before `pytest_configure`. `#362 <https://github.com/pytest-dev/pytest-bdd/issues/362>`_ `#641 <https://github.com/pytest-dev/pytest-bdd/pull/641>`_
7.0.0
----------

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from collections.abc import Sequence
from importlib.metadata import version
from typing import Any
from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest
from _pytest.nodes import Node
@ -14,10 +13,12 @@ __all__ = ["getfixturedefs", "inject_fixture"]
if pytest_version.release >= (8, 1):
def getfixturedefs(fixturemanager: FixtureManager, fixturename: str, node: Node) -> Sequence[FixtureDef] | None:
def getfixturedefs(
fixturemanager: FixtureManager, fixturename: str, node: Node
) -> Sequence[FixtureDef[object]] | None:
return fixturemanager.getfixturedefs(fixturename, node)
def inject_fixture(request: FixtureRequest, arg: str, value: Any) -> None:
def inject_fixture(request: FixtureRequest, arg: str, value: object) -> None:
"""Inject fixture into pytest fixture request.
:param request: pytest fixture request
@ -38,10 +39,12 @@ if pytest_version.release >= (8, 1):
else:
def getfixturedefs(fixturemanager: FixtureManager, fixturename: str, node: Node) -> Sequence[FixtureDef] | None:
def getfixturedefs(
fixturemanager: FixtureManager, fixturename: str, node: Node
) -> Sequence[FixtureDef[object]] | None:
return fixturemanager.getfixturedefs(fixturename, node.nodeid) # type: ignore
def inject_fixture(request: FixtureRequest, arg: str, value: Any) -> None:
def inject_fixture(request: FixtureRequest, arg: str, value: object) -> None:
"""Inject fixture into pytest fixture request.
:param request: pytest fixture request

View File

@ -6,17 +6,69 @@ import json
import math
import os
import time
import typing
from typing import TYPE_CHECKING, Literal, TypedDict
if typing.TYPE_CHECKING:
from typing import Any
from typing_extensions import NotRequired
from .reporting import FeatureDict, ScenarioReportDict, StepReportDict, test_report_context_registry
if TYPE_CHECKING:
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.reports import TestReport
from _pytest.terminal import TerminalReporter
class ResultElementDict(TypedDict):
status: Literal["passed", "failed", "skipped"]
duration: int # in nanoseconds
error_message: NotRequired[str]
class TagElementDict(TypedDict):
name: str
line: int
class MatchElementDict(TypedDict):
location: str
class StepElementDict(TypedDict):
keyword: str
name: str
line: int
match: MatchElementDict
result: ResultElementDict
class ScenarioElementDict(TypedDict):
keyword: str
id: str
name: str
line: int
description: str
tags: list[TagElementDict]
type: Literal["scenario"]
steps: list[StepElementDict]
class FeatureElementDict(TypedDict):
keyword: str
uri: str
name: str
id: str
line: int
description: str
language: str
tags: list[TagElementDict]
elements: list[ScenarioElementDict]
class FeaturesDict(TypedDict):
features: dict[str, FeatureElementDict]
def add_options(parser: Parser) -> None:
"""Add pytest-bdd options."""
group = parser.getgroup("bdd", "Cucumber JSON")
@ -52,26 +104,32 @@ class LogBDDCucumberJSON:
def __init__(self, logfile: str) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile))
self.features: dict[str, dict] = {}
self.features: dict[str, FeatureElementDict] = {}
def _get_result(self, step: dict[str, Any], report: TestReport, error_message: bool = False) -> dict[str, Any]:
def _get_result(self, step: StepReportDict, report: TestReport, error_message: bool = False) -> ResultElementDict:
"""Get scenario test run result.
:param step: `Step` step we get result for
:param report: pytest `Report` object
:return: `dict` in form {"status": "<passed|failed|skipped>", ["error_message": "<error_message>"]}
"""
result: dict[str, Any] = {}
if report.passed or not step["failed"]: # ignore setup/teardown
result = {"status": "passed"}
elif report.failed:
result = {"status": "failed", "error_message": str(report.longrepr) if error_message else ""}
elif report.skipped:
result = {"status": "skipped"}
result["duration"] = int(math.floor((10**9) * step["duration"])) # nanosec
return result
status: Literal["passed", "failed", "skipped"]
res_message = None
if report.outcome == "passed" or not step["failed"]: # ignore setup/teardown
status = "passed"
elif report.outcome == "failed":
status = "failed"
res_message = str(report.longrepr) if error_message else ""
elif report.outcome == "skipped":
status = "skipped"
else:
raise ValueError(f"Unknown test outcome {report.outcome}")
res: ResultElementDict = {"status": status, "duration": int(math.floor((10**9) * step["duration"]))} # nanosec
if res_message is not None:
res["error_message"] = res_message
return res
def _serialize_tags(self, item: dict[str, Any]) -> list[dict[str, Any]]:
def _serialize_tags(self, item: FeatureDict | ScenarioReportDict) -> list[TagElementDict]:
"""Serialize item's tags.
:param item: json-serialized `Scenario` or `Feature`.
@ -87,8 +145,8 @@ class LogBDDCucumberJSON:
def pytest_runtest_logreport(self, report: TestReport) -> None:
try:
scenario = report.scenario
except AttributeError:
scenario = test_report_context_registry[report].scenario
except KeyError:
# skip reporting for non-bdd tests
return
@ -96,7 +154,7 @@ class LogBDDCucumberJSON:
# skip if there isn't a result or scenario has no steps
return
def stepmap(step: dict[str, Any]) -> dict[str, Any]:
def stepmap(step: StepReportDict) -> StepElementDict:
error_message = False
if step["failed"] and not scenario.setdefault("failed", False):
scenario["failed"] = True
@ -128,7 +186,7 @@ class LogBDDCucumberJSON:
self.features[scenario["feature"]["filename"]]["elements"].append(
{
"keyword": scenario["keyword"],
"id": report.item["name"],
"id": test_report_context_registry[report].name,
"name": scenario["name"],
"line": scenario["line_number"],
"description": scenario["description"],

View File

@ -28,6 +28,7 @@ from __future__ import annotations
import glob
import os.path
from collections.abc import Iterable
from .parser import Feature, FeatureParser
@ -57,7 +58,7 @@ def get_feature(base_path: str, filename: str, encoding: str = "utf-8") -> Featu
return feature
def get_features(paths: list[str], encoding: str = "utf-8") -> list[Feature]:
def get_features(paths: Iterable[str], encoding: str = "utf-8") -> list[Feature]:
"""Get features for given paths.
:param list paths: `list` of paths (file or dirs)

View File

@ -7,24 +7,30 @@ import os.path
from typing import TYPE_CHECKING, cast
from _pytest._io import TerminalWriter
from _pytest.python import Function
from mako.lookup import TemplateLookup # type: ignore
from .compat import getfixturedefs
from .feature import get_features
from .parser import Feature, ScenarioTemplate, Step
from .scenario import inject_fixturedefs_for_step, make_python_docstring, make_python_name, make_string_literal
from .scenario import (
inject_fixturedefs_for_step,
make_python_docstring,
make_python_name,
make_string_literal,
scenario_wrapper_template_registry,
)
from .steps import get_step_fixture_name
from .types import STEP_TYPES
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef, FixtureManager
from _pytest.main import Session
from _pytest.python import Function
from _pytest.nodes import Node
template_lookup = TemplateLookup(directories=[os.path.join(os.path.dirname(__file__), "templates")])
@ -127,15 +133,17 @@ def print_missing_code(scenarios: list[ScenarioTemplate], steps: list[Step]) ->
def _find_step_fixturedef(
fixturemanager: FixtureManager, item: Function, step: Step
) -> Sequence[FixtureDef[Any]] | None:
fixturemanager: FixtureManager, item: Node, step: Step
) -> Sequence[FixtureDef[object]] | None:
"""Find step fixturedef."""
with inject_fixturedefs_for_step(step=step, fixturemanager=fixturemanager, node=item):
bdd_name = get_step_fixture_name(step=step)
return getfixturedefs(fixturemanager, bdd_name, item)
def parse_feature_files(paths: list[str], **kwargs: Any) -> tuple[list[Feature], list[ScenarioTemplate], list[Step]]:
def parse_feature_files(
paths: list[str], encoding: str = "utf-8"
) -> tuple[list[Feature], list[ScenarioTemplate], list[Step]]:
"""Parse feature files of given paths.
:param paths: `list` of paths (file or dirs)
@ -143,7 +151,7 @@ def parse_feature_files(paths: list[str], **kwargs: Any) -> tuple[list[Feature],
:return: `list` of `tuple` in form:
(`list` of `Feature` objects, `list` of `Scenario` objects, `list` of `Step` objects).
"""
features = get_features(paths, **kwargs)
features = get_features(paths, encoding=encoding)
scenarios = sorted(
itertools.chain.from_iterable(feature.scenarios.values() for feature in features),
key=lambda scenario: (scenario.feature.name or scenario.feature.filename, scenario.name),
@ -182,7 +190,9 @@ def _show_missing_code_main(config: Config, session: Session) -> None:
features, scenarios, steps = parse_feature_files(config.option.features)
for item in session.items:
if scenario := getattr(item.obj, "__scenario__", None): # type: ignore
if not isinstance(item, Function):
continue
if (scenario := scenario_wrapper_template_registry.get(item.obj)) is not None:
if scenario in scenarios:
scenarios.remove(scenario)
for step in scenario.steps:

View File

@ -4,9 +4,9 @@ import typing
from _pytest.terminal import TerminalReporter
if typing.TYPE_CHECKING:
from typing import Any
from .reporting import test_report_context_registry
if typing.TYPE_CHECKING:
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.reports import TestReport
@ -43,12 +43,12 @@ def configure(config: Config) -> None:
raise Exception("gherkin-terminal-reporter is not compatible with 'xdist' plugin.")
class GherkinTerminalReporter(TerminalReporter): # type: ignore
class GherkinTerminalReporter(TerminalReporter): # type: ignore[misc]
def __init__(self, config: Config) -> None:
super().__init__(config)
self.current_rule = None
self.current_rule: str | None = None
def pytest_runtest_logreport(self, report: TestReport) -> Any:
def pytest_runtest_logreport(self, report: TestReport) -> None:
rep = report
res = self.config.hook.pytest_report_teststatus(report=rep, config=self.config)
cat, letter, word = res
@ -69,16 +69,21 @@ class GherkinTerminalReporter(TerminalReporter): # type: ignore
scenario_markup = word_markup
rule_markup = {"purple": True}
if self.verbosity <= 0 or not hasattr(report, "scenario"):
try:
scenario = test_report_context_registry[report].scenario
except KeyError:
scenario = None
if self.verbosity <= 0 or scenario is None:
return super().pytest_runtest_logreport(rep)
rule = report.scenario.get("rule")
rule = scenario.get("rule")
indent = " " if rule else ""
if self.verbosity == 1:
self.ensure_newline()
self._tw.write(f"{report.scenario['feature']['keyword']}: ", **feature_markup)
self._tw.write(report.scenario["feature"]["name"], **feature_markup)
self._tw.write(f"{scenario['feature']['keyword']}: ", **feature_markup)
self._tw.write(scenario["feature"]["name"], **feature_markup)
self._tw.write("\n")
if rule and rule["name"] != self.current_rule:
@ -87,15 +92,15 @@ class GherkinTerminalReporter(TerminalReporter): # type: ignore
self._tw.write("\n")
self.current_rule = rule["name"]
self._tw.write(f"{indent} {report.scenario['keyword']}: ", **scenario_markup)
self._tw.write(report.scenario["name"], **scenario_markup)
self._tw.write(f"{indent} {scenario['keyword']}: ", **scenario_markup)
self._tw.write(scenario["name"], **scenario_markup)
self._tw.write(" ")
self._tw.write(word, **word_markup)
self._tw.write("\n")
elif self.verbosity > 1:
self.ensure_newline()
self._tw.write(f"{report.scenario['feature']['keyword']}: ", **feature_markup)
self._tw.write(report.scenario["feature"]["name"], **feature_markup)
self._tw.write(f"{scenario['feature']['keyword']}: ", **feature_markup)
self._tw.write(scenario["feature"]["name"], **feature_markup)
self._tw.write("\n")
if rule and rule["name"] != self.current_rule:
@ -104,13 +109,12 @@ class GherkinTerminalReporter(TerminalReporter): # type: ignore
self._tw.write("\n")
self.current_rule = rule["name"]
self._tw.write(f"{indent} {report.scenario['keyword']}: ", **scenario_markup)
self._tw.write(report.scenario["name"], **scenario_markup)
self._tw.write(f"{indent} {scenario['keyword']}: ", **scenario_markup)
self._tw.write(scenario["name"], **scenario_markup)
self._tw.write("\n")
for step in report.scenario["steps"]:
for step in scenario["steps"]:
self._tw.write(f"{indent} {step['keyword']} {step['name']}\n", **scenario_markup)
self._tw.write(f"{indent} {word}", **word_markup)
self._tw.write("\n\n")
self.stats.setdefault(cat, []).append(rep)
return None

View File

@ -7,7 +7,6 @@ import textwrap
from collections import OrderedDict
from collections.abc import Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any
from .exceptions import StepError
from .gherkin_parser import Background as GherkinBackground
@ -65,7 +64,7 @@ class Feature:
scenarios (OrderedDict[str, ScenarioTemplate]): A dictionary of scenarios in the feature.
filename (str): The absolute path of the feature file.
rel_filename (str): The relative path of the feature file.
name (Optional[str]): The name of the feature.
name (str): The name of the feature.
tags (set[str]): A set of tags associated with the feature.
background (Optional[Background]): The background steps for the feature, if any.
line_number (int): The line number where the feature starts in the file.
@ -77,7 +76,7 @@ class Feature:
rel_filename: str
language: str
keyword: str
name: str | None
name: str
tags: set[str]
background: Background | None
line_number: int
@ -117,11 +116,11 @@ class Examples:
"""
self.examples.append([str(value) if value is not None else "" for value in values])
def as_contexts(self) -> Iterable[dict[str, Any]]:
def as_contexts(self) -> Generator[dict[str, str]]:
"""Generate contexts for the examples.
Yields:
Dict[str, Any]: A dictionary mapping parameter names to their values for each example row.
dict[str, str]: A dictionary mapping parameter names to their values for each example row.
"""
for row in self.examples:
assert len(self.example_params) == len(row)
@ -167,7 +166,7 @@ class ScenarioTemplate:
name: str
line_number: int
templated: bool
description: str | None = None
description: str
tags: set[str] = field(default_factory=set)
_steps: list[Step] = field(init=False, default_factory=list)
examples: list[Examples] = field(default_factory=list[Examples])
@ -202,11 +201,11 @@ class ScenarioTemplate:
"""
return self.all_background_steps + self._steps
def render(self, context: Mapping[str, Any]) -> Scenario:
def render(self, context: Mapping[str, object]) -> Scenario:
"""Render the scenario with the given context.
Args:
context (Mapping[str, Any]): The context for rendering steps.
context (Mapping[str, object]): The context for rendering steps.
Returns:
Scenario: A Scenario object with steps rendered based on the context.
@ -255,7 +254,7 @@ class Scenario:
name: str
line_number: int
steps: list[Step]
description: str | None = None
description: str
tags: set[str] = field(default_factory=set)
rule: Rule | None = None
@ -329,7 +328,7 @@ class Step:
Args:
datatable (DataTable): The datatable to render.
context (Mapping[str, Any]): The context for rendering the datatable.
context (Mapping[str, object]): The context for rendering the datatable.
Returns:
datatable (DataTable): The rendered datatable with parameters replaced only if they exist in the context.

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Callable, TypeVar, cast
import pytest
from typing_extensions import ParamSpec
@ -99,8 +99,8 @@ def pytest_bdd_step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
step_func_args: dict,
step_func: Callable[..., object],
step_func_args: dict[str, object],
exception: Exception,
) -> None:
reporting.step_error(request, feature, scenario, step, step_func, step_func_args, exception)
@ -112,7 +112,7 @@ def pytest_bdd_before_step(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
step_func: Callable[..., object],
) -> None:
reporting.before_step(request, feature, scenario, step, step_func)
@ -123,8 +123,8 @@ def pytest_bdd_after_step(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
step_func_args: dict[str, Any],
step_func: Callable[..., object],
step_func_args: dict[str, object],
) -> None:
reporting.after_step(request, feature, scenario, step, step_func, step_func_args)

View File

@ -7,11 +7,13 @@ that enriches the pytest test reporting.
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, TypedDict
from weakref import WeakKeyDictionary
from typing_extensions import NotRequired
if TYPE_CHECKING:
from typing import Any, Callable
from _pytest.fixtures import FixtureRequest
from _pytest.nodes import Item
from _pytest.reports import TestReport
@ -19,12 +21,54 @@ if TYPE_CHECKING:
from .parser import Feature, Scenario, Step
scenario_reports_registry: WeakKeyDictionary[Item, ScenarioReport] = WeakKeyDictionary()
test_report_context_registry: WeakKeyDictionary[TestReport, ReportContext] = WeakKeyDictionary()
class FeatureDict(TypedDict):
keyword: str
name: str
filename: str
rel_filename: str
language: str
line_number: int
description: str
tags: list[str]
class RuleDict(TypedDict):
keyword: str
name: str
description: str
tags: list[str]
class StepReportDict(TypedDict):
name: str
type: str
keyword: str
line_number: int
failed: bool
duration: float
class ScenarioReportDict(TypedDict):
steps: list[StepReportDict]
keyword: str
name: str
line_number: int
tags: list[str]
feature: FeatureDict
description: str
rule: NotRequired[RuleDict]
failed: NotRequired[bool]
class StepReport:
"""Step execution report."""
failed = False
stopped = None
failed: bool = False
stopped: float | None = None
def __init__(self, step: Step) -> None:
"""Step report constructor.
@ -34,11 +78,10 @@ class StepReport:
self.step = step
self.started = time.perf_counter()
def serialize(self) -> dict[str, Any]:
def serialize(self) -> StepReportDict:
"""Serialize the step execution report.
:return: Serialized step execution report.
:rtype: dict
"""
return {
"name": self.step.name,
@ -98,16 +141,15 @@ class ScenarioReport:
"""
self.step_reports.append(step_report)
def serialize(self) -> dict[str, Any]:
def serialize(self) -> ScenarioReportDict:
"""Serialize scenario execution report in order to transfer reporting from nodes in the distributed mode.
:return: Serialized report.
:rtype: dict
"""
scenario = self.scenario
feature = scenario.feature
serialized = {
serialized: ScenarioReportDict = {
"steps": [step_report.serialize() for step_report in self.step_reports],
"keyword": scenario.keyword,
"name": scenario.name,
@ -127,12 +169,13 @@ class ScenarioReport:
}
if scenario.rule:
serialized["rule"] = {
rule_dict: RuleDict = {
"keyword": scenario.rule.keyword,
"name": scenario.rule.name,
"description": scenario.rule.description,
"tags": scenario.rule.tags,
"tags": sorted(scenario.rule.tags),
}
serialized["rule"] = rule_dict
return serialized
@ -148,17 +191,25 @@ class ScenarioReport:
self.add_step_report(report)
@dataclass
class ReportContext:
scenario: ScenarioReportDict
name: str
def runtest_makereport(item: Item, call: CallInfo, rep: TestReport) -> None:
"""Store item in the report object."""
scenario_report = getattr(item, "__scenario_report__", None)
if scenario_report is not None:
rep.scenario = scenario_report.serialize() # type: ignore
rep.item = {"name": item.name} # type: ignore
try:
scenario_report: ScenarioReport = scenario_reports_registry[item]
except KeyError:
return
test_report_context_registry[rep] = ReportContext(scenario=scenario_report.serialize(), name=item.name)
def before_scenario(request: FixtureRequest, feature: Feature, scenario: Scenario) -> None:
"""Create scenario report for the item."""
request.node.__scenario_report__ = ScenarioReport(scenario=scenario)
scenario_reports_registry[request.node] = ScenarioReport(scenario=scenario)
def step_error(
@ -166,12 +217,12 @@ def step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
step_func_args: dict,
step_func: Callable[..., object],
step_func_args: dict[str, object],
exception: Exception,
) -> None:
"""Finalize the step report as failed."""
request.node.__scenario_report__.fail()
scenario_reports_registry[request.node].fail()
def before_step(
@ -179,10 +230,10 @@ def before_step(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
step_func: Callable[..., object],
) -> None:
"""Store step start time."""
request.node.__scenario_report__.add_step_report(StepReport(step=step))
scenario_reports_registry[request.node].add_step_report(StepReport(step=step))
def after_step(
@ -194,4 +245,4 @@ def after_step(
step_func_args: dict,
) -> None:
"""Finalize the step report as successful."""
request.node.__scenario_report__.current_step_report.finalize(failed=False)
scenario_reports_registry[request.node].current_step_report.finalize(failed=False)

View File

@ -19,17 +19,24 @@ import os
import re
from collections.abc import Iterable, Iterator
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Callable, TypeVar, cast
from weakref import WeakKeyDictionary
import pytest
from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func
from typing_extensions import ParamSpec
from . import exceptions
from .compat import getfixturedefs, inject_fixture
from .feature import get_feature, get_features
from .steps import StepFunctionContext, get_step_fixture_name
from .utils import CONFIG_STACK, get_caller_module_locals, get_caller_module_path, get_required_args, identity
from .steps import StepFunctionContext, get_step_fixture_name, step_function_context_registry
from .utils import (
CONFIG_STACK,
get_caller_module_locals,
get_caller_module_path,
get_required_args,
identity,
registry_get_safe,
)
if TYPE_CHECKING:
from _pytest.mark.structures import ParameterSet
@ -37,7 +44,6 @@ if TYPE_CHECKING:
from .parser import Feature, Scenario, ScenarioTemplate, Step
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)
@ -49,14 +55,16 @@ STEP_ARGUMENT_DATATABLE = "datatable"
STEP_ARGUMENT_DOCSTRING = "docstring"
STEP_ARGUMENTS_RESERVED_NAMES = {STEP_ARGUMENT_DATATABLE, STEP_ARGUMENT_DOCSTRING}
scenario_wrapper_template_registry: WeakKeyDictionary[Callable[..., object], ScenarioTemplate] = WeakKeyDictionary()
def find_fixturedefs_for_step(step: Step, fixturemanager: FixtureManager, node: Node) -> Iterable[FixtureDef[Any]]:
def find_fixturedefs_for_step(step: Step, fixturemanager: FixtureManager, node: Node) -> Iterable[FixtureDef[object]]:
"""Find the fixture defs that can parse a step."""
# happens to be that _arg2fixturedefs is changed during the iteration so we use a copy
fixture_def_by_name = list(fixturemanager._arg2fixturedefs.items())
for fixturename, fixturedefs in fixture_def_by_name:
for _, fixturedef in enumerate(fixturedefs):
step_func_context = getattr(fixturedef.func, "_pytest_bdd_step_context", None)
step_func_context = step_function_context_registry.get(fixturedef.func)
if step_func_context is None:
continue
@ -67,7 +75,7 @@ def find_fixturedefs_for_step(step: Step, fixturemanager: FixtureManager, node:
if not match:
continue
fixturedefs = cast(list[FixtureDef[Any]], getfixturedefs(fixturemanager, fixturename, node) or [])
fixturedefs = list(getfixturedefs(fixturemanager, fixturename, node) or [])
if fixturedef not in fixturedefs:
continue
@ -278,14 +286,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ
def _get_scenario_decorator(
feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str
) -> Callable[[Callable[P, T]], Callable[P, T]]:
) -> Callable[[Callable[..., T]], Callable[[FixtureRequest, dict[str, str]], T]]:
# HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception
# when the decorator is misused.
# Pytest inspect the signature to determine the required fixtures, and in that case it would look
# for a fixture called "fn" that doesn't exist (if it exists then it's even worse).
# It will error with a "fixture 'fn' not found" message instead.
# We can avoid this hack by using a pytest hook and check for misuse instead.
def decorator(*args: Callable[P, T]) -> Callable[P, T]:
def decorator(*args: Callable[..., T]) -> Callable[[FixtureRequest, dict[str, str]], T]:
if not args:
raise exceptions.ScenarioIsDecoratorOnly(
"scenario function can only be used as a decorator. Refer to the documentation."
@ -293,7 +301,7 @@ def _get_scenario_decorator(
[fn] = args
func_args = get_required_args(fn)
def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str]) -> Any:
def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str]) -> T:
__tracebackhide__ = True
scenario = templated_scenario.render(_pytest_bdd_example)
_execute_scenario(feature, scenario, request)
@ -319,8 +327,9 @@ def _get_scenario_decorator(
config.hook.pytest_bdd_apply_tag(tag=tag, function=scenario_wrapper)
scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}"
scenario_wrapper.__scenario__ = templated_scenario # type: ignore[attr-defined]
return cast(Callable[P, T], scenario_wrapper)
scenario_wrapper_template_registry[scenario_wrapper] = templated_scenario
return scenario_wrapper
return decorator
@ -353,7 +362,7 @@ def scenario(
scenario_name: str,
encoding: str = "utf-8",
features_base_dir: str | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Scenario decorator.
:param str feature_name: Feature file name. Absolute or relative to the configured feature base path.
@ -435,15 +444,17 @@ def get_python_name_generator(name: str) -> Iterable[str]:
suffix = f"_{index}"
def scenarios(*feature_paths: str, **kwargs: Any) -> None:
def scenarios(*feature_paths: str, encoding: str = "utf-8", features_base_dir: str | None = None) -> None:
caller_locals = get_caller_module_locals()
"""Parse features from the paths and put all found scenarios in the caller module.
:param *feature_paths: feature file paths to use for scenarios
:param str encoding: Feature file encoding.
:param features_base_dir: Optional base dir location for locating feature files. If not set, it will try and
resolve using property set in .ini file, otherwise it is assumed to be relative from the caller path location.
"""
caller_locals = get_caller_module_locals()
caller_path = get_caller_module_path()
features_base_dir = kwargs.get("features_base_dir")
if features_base_dir is None:
features_base_dir = get_features_base_dir(caller_path)
@ -455,9 +466,9 @@ def scenarios(*feature_paths: str, **kwargs: Any) -> None:
found = False
module_scenarios = frozenset(
(attr.__scenario__.feature.filename, attr.__scenario__.name)
(s.feature.filename, s.name)
for name, attr in caller_locals.items()
if hasattr(attr, "__scenario__")
if (s := registry_get_safe(scenario_wrapper_template_registry, attr)) is not None
)
for feature in get_features(abs_feature_paths):
@ -465,7 +476,7 @@ def scenarios(*feature_paths: str, **kwargs: Any) -> None:
# skip already bound scenarios
if (scenario_object.feature.filename, scenario_name) not in module_scenarios:
@scenario(feature.filename, scenario_name, **kwargs)
@scenario(feature.filename, scenario_name, encoding=encoding, features_base_dir=features_base_dir)
def _scenario() -> None:
pass # pragma: no cover

View File

@ -41,19 +41,21 @@ import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import count
from typing import Any, Callable, Literal, TypeVar
from typing import Callable, Literal, TypeVar
from weakref import WeakKeyDictionary
import pytest
from typing_extensions import ParamSpec
from .parser import Step
from .parsers import StepParser, get_parser
from .types import GIVEN, THEN, WHEN
from .utils import get_caller_module_locals
P = ParamSpec("P")
T = TypeVar("T")
step_function_context_registry: WeakKeyDictionary[Callable[..., object], StepFunctionContext] = WeakKeyDictionary()
@enum.unique
class StepNamePrefix(enum.Enum):
@ -64,9 +66,9 @@ class StepNamePrefix(enum.Enum):
@dataclass
class StepFunctionContext:
type: Literal["given", "when", "then"] | None
step_func: Callable[..., Any]
step_func: Callable[..., object]
parser: StepParser
converters: dict[str, Callable[[str], Any]] = field(default_factory=dict)
converters: dict[str, Callable[[str], object]] = field(default_factory=dict)
target_fixture: str | None = None
@ -77,7 +79,7 @@ def get_step_fixture_name(step: Step) -> str:
def given(
name: str | StepParser,
converters: dict[str, Callable[[str], Any]] | None = None,
converters: dict[str, Callable[[str], object]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
@ -91,12 +93,12 @@ def given(
:return: Decorator function for the step.
"""
return step(name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(name, "given", converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
def when(
name: str | StepParser,
converters: dict[str, Callable[[str], Any]] | None = None,
converters: dict[str, Callable[[str], object]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
@ -110,12 +112,12 @@ def when(
:return: Decorator function for the step.
"""
return step(name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(name, "when", converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
def then(
name: str | StepParser,
converters: dict[str, Callable[[str], Any]] | None = None,
converters: dict[str, Callable[[str], object]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
@ -129,13 +131,13 @@ def then(
:return: Decorator function for the step.
"""
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(name, "then", converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
def step(
name: str | StepParser,
type_: Literal["given", "when", "then"] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
converters: dict[str, Callable[[str], object]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
@ -172,7 +174,7 @@ def step(
def step_function_marker() -> StepFunctionContext:
return context
step_function_marker._pytest_bdd_step_context = context # type: ignore
step_function_context_registry[step_function_marker] = context
caller_locals = get_caller_module_locals(stacklevel=stacklevel)
fixture_step_name = find_unique_name(

View File

@ -7,20 +7,21 @@ import pickle
import re
from inspect import getframeinfo, signature
from sys import _getframe
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING, Callable, TypeVar, cast, overload
from weakref import WeakKeyDictionary
if TYPE_CHECKING:
from typing import Any, Callable
from _pytest.config import Config
from _pytest.pytester import RunResult
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
CONFIG_STACK: list[Config] = []
def get_required_args(func: Callable[..., Any]) -> list[str]:
def get_required_args(func: Callable[..., object]) -> list[str]:
"""Get a list of argument that are required for a function.
:param func: The function to inspect.
@ -33,7 +34,7 @@ def get_required_args(func: Callable[..., Any]) -> list[str]:
]
def get_caller_module_locals(stacklevel: int = 1) -> dict[str, Any]:
def get_caller_module_locals(stacklevel: int = 1) -> dict[str, object]:
"""Get the caller module locals dictionary.
We use sys._getframe instead of inspect.stack(0) because the latter is way slower, since it iterates over
@ -56,7 +57,7 @@ _DUMP_START = "_pytest_bdd_>>>"
_DUMP_END = "<<<_pytest_bdd_"
def dump_obj(*objects: Any) -> None:
def dump_obj(*objects: object) -> None:
"""Dump objects to stdout so that they can be inspected by the test suite."""
for obj in objects:
dump = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
@ -87,3 +88,19 @@ def setdefault(obj: object, name: str, default: T) -> T:
def identity(x: T) -> T:
"""Return the argument."""
return x
@overload
def registry_get_safe(registry: WeakKeyDictionary[K, V], key: object, default: T) -> V | T: ...
@overload
def registry_get_safe(registry: WeakKeyDictionary[K, V], key: object, default: None = None) -> V | None: ...
def registry_get_safe(registry: WeakKeyDictionary[K, V], key: object, default: T | None = None) -> T | V | None:
"""Get a value from a registry, or None if the key is not in the registry.
It ensures that this works even if the key cannot be weak-referenced (normally this would raise a TypeError).
"""
try:
return registry.get(key, default) # type: ignore[arg-type]
except TypeError:
return None

View File

@ -31,27 +31,30 @@ def test_description(pytester):
pytester.makepyfile(
textwrap.dedent(
r'''
import textwrap
from pytest_bdd import given, scenario
import textwrap
from pytest_bdd import given, scenario
from pytest_bdd.scenario import scenario_wrapper_template_registry
@scenario("description.feature", "Description")
def test_description():
pass
@scenario("description.feature", "Description")
def test_description():
pass
@given("I have a bar")
def _():
return "bar"
@given("I have a bar")
def _():
return "bar"
def test_feature_description():
assert test_description.__scenario__.feature.description == textwrap.dedent(
"In order to achieve something\nI want something\nBecause it will be cool\n\n\nSome description goes here."
)
def test_feature_description():
scenario = scenario_wrapper_template_registry[test_description]
assert scenario.feature.description == textwrap.dedent(
"In order to achieve something\nI want something\nBecause it will be cool\n\n\nSome description goes here."
)
def test_scenario_description():
assert test_description.__scenario__.description == textwrap.dedent(
"Also, the scenario can have a description.\n\nIt goes here between the scenario name\nand the first step."""
)
def test_scenario_description():
scenario = scenario_wrapper_template_registry[test_description]
assert scenario.description == textwrap.dedent(
"Also, the scenario can have a description.\n\nIt goes here between the scenario name\nand the first step."""
)
'''
)
)

View File

@ -5,6 +5,8 @@ from typing import Optional
import pytest
from pytest_bdd.reporting import test_report_context_registry
class OfType:
"""Helper object comparison to which is always 'equal'."""
@ -102,7 +104,8 @@ def test_step_trace(pytester):
)
result = pytester.inline_run("-vvl")
assert result.ret
report = result.matchreport("test_passing", when="call").scenario
report = result.matchreport("test_passing", when="call")
scenario = test_report_context_registry[report].scenario
expected = {
"feature": {
"description": "",
@ -139,9 +142,10 @@ def test_step_trace(pytester):
"tags": ["scenario-passing-tag"],
}
assert report == expected
assert scenario == expected
report = result.matchreport("test_failing", when="call").scenario
report = result.matchreport("test_failing", when="call")
scenario = test_report_context_registry[report].scenario
expected = {
"feature": {
"description": "",
@ -177,9 +181,10 @@ def test_step_trace(pytester):
],
"tags": ["scenario-failing-tag"],
}
assert report == expected
assert scenario == expected
report = result.matchreport("test_outlined[12-5-7]", when="call").scenario
report = result.matchreport("test_outlined[12-5-7]", when="call")
scenario = test_report_context_registry[report].scenario
expected = {
"feature": {
"description": "",
@ -223,9 +228,10 @@ def test_step_trace(pytester):
],
"tags": [],
}
assert report == expected
assert scenario == expected
report = result.matchreport("test_outlined[5-4-1]", when="call").scenario
report = result.matchreport("test_outlined[5-4-1]", when="call")
scenario = test_report_context_registry[report].scenario
expected = {
"feature": {
"description": "",
@ -269,7 +275,7 @@ def test_step_trace(pytester):
],
"tags": [],
}
assert report == expected
assert scenario == expected
def test_complex_types(pytester, pytestconfig):
@ -334,5 +340,7 @@ def test_complex_types(pytester, pytestconfig):
result = pytester.inline_run("-vvl")
report = result.matchreport("test_complex[10,20-alien0]", when="call")
assert report.passed
assert execnet.gateway_base.dumps(report.item)
assert execnet.gateway_base.dumps(report.scenario)
report_context = test_report_context_registry[report]
assert execnet.gateway_base.dumps(report_context.name)
assert execnet.gateway_base.dumps(report_context.scenario)