Merge pull request #12496 from RonnyPfannschmidt/ronny/backport-annotations-8.2

backport annotations to 8.2
This commit is contained in:
Ronny Pfannschmidt 2024-06-21 10:25:23 +02:00 committed by GitHub
commit 76065e5028
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
244 changed files with 2040 additions and 1824 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import sys import sys

View File

@ -2,6 +2,8 @@
# 2.7.5 3.3.2 # 2.7.5 3.3.2
# FilesCompleter 75.1109 69.2116 # FilesCompleter 75.1109 69.2116
# FastFilesCompleter 0.7383 1.0760 # FastFilesCompleter 0.7383 1.0760
from __future__ import annotations
import timeit import timeit

View File

@ -1,2 +1,5 @@
from __future__ import annotations
for i in range(1000): for i in range(1000):
exec("def test_func_%d(): pass" % i) exec("def test_func_%d(): pass" % i)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from unittest import TestCase # noqa: F401 from unittest import TestCase # noqa: F401

View File

@ -1,3 +1,6 @@
from __future__ import annotations
for i in range(5000): for i in range(5000):
exec( exec(
f""" f"""

View File

@ -0,0 +1,3 @@
Migrated all internal type-annotations to the python3.10+ style by using the `annotations` future import.
-- by :user:`RonnyPfannschmidt`

View File

@ -15,15 +15,19 @@
# #
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
# The short X.Y version. # The short X.Y version.
from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
import shutil import shutil
from textwrap import dedent from textwrap import dedent
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from _pytest import __version__ as version from _pytest import __version__ as full_version
version = full_version.split("+")[0]
if TYPE_CHECKING: if TYPE_CHECKING:
import sphinx.application import sphinx.application
@ -189,6 +193,7 @@ nitpick_ignore = [
("py:class", "SubRequest"), ("py:class", "SubRequest"),
("py:class", "TerminalReporter"), ("py:class", "TerminalReporter"),
("py:class", "_pytest._code.code.TerminalRepr"), ("py:class", "_pytest._code.code.TerminalRepr"),
("py:class", "TerminalRepr"),
("py:class", "_pytest.fixtures.FixtureFunctionMarker"), ("py:class", "_pytest.fixtures.FixtureFunctionMarker"),
("py:class", "_pytest.logging.LogCaptureHandler"), ("py:class", "_pytest.logging.LogCaptureHandler"),
("py:class", "_pytest.mark.structures.ParameterSet"), ("py:class", "_pytest.mark.structures.ParameterSet"),
@ -210,13 +215,16 @@ nitpick_ignore = [
("py:class", "_PluggyPlugin"), ("py:class", "_PluggyPlugin"),
# TypeVars # TypeVars
("py:class", "_pytest._code.code.E"), ("py:class", "_pytest._code.code.E"),
("py:class", "E"), # due to delayed annotation
("py:class", "_pytest.fixtures.FixtureFunction"), ("py:class", "_pytest.fixtures.FixtureFunction"),
("py:class", "_pytest.nodes._NodeType"), ("py:class", "_pytest.nodes._NodeType"),
("py:class", "_NodeType"), # due to delayed annotation
("py:class", "_pytest.python_api.E"), ("py:class", "_pytest.python_api.E"),
("py:class", "_pytest.recwarn.T"), ("py:class", "_pytest.recwarn.T"),
("py:class", "_pytest.runner.TResult"), ("py:class", "_pytest.runner.TResult"),
("py:obj", "_pytest.fixtures.FixtureValue"), ("py:obj", "_pytest.fixtures.FixtureValue"),
("py:obj", "_pytest.stash.T"), ("py:obj", "_pytest.stash.T"),
("py:class", "_ScopeName"),
] ]
@ -455,7 +463,7 @@ intersphinx_mapping = {
} }
def setup(app: "sphinx.application.Sphinx") -> None: def setup(app: sphinx.application.Sphinx) -> None:
app.add_crossref_type( app.add_crossref_type(
"fixture", "fixture",
"fixture", "fixture",

View File

@ -1 +1,4 @@
from __future__ import annotations
collect_ignore = ["conf.py"] collect_ignore = ["conf.py"]

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest
from pytest import raises from pytest import raises

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os.path import os.path
import pytest import pytest

View File

@ -1,3 +1,6 @@
from __future__ import annotations
hello = "world" hello = "world"

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os.path import os.path
import shutil import shutil

View File

@ -1,3 +1,6 @@
from __future__ import annotations
def setup_module(module): def setup_module(module):
module.TestStateFullThing.classcount = 0 module.TestStateFullThing.classcount = 0

View File

@ -1 +1,4 @@
from __future__ import annotations
collect_ignore = ["nonpython", "customdirectory"] collect_ignore = ["nonpython", "customdirectory"]

View File

@ -1,4 +1,6 @@
# content of conftest.py # content of conftest.py
from __future__ import annotations
import json import json
import pytest import pytest

View File

@ -1,3 +1,6 @@
# content of test_first.py # content of test_first.py
from __future__ import annotations
def test_1(): def test_1():
pass pass

View File

@ -1,3 +1,6 @@
# content of test_second.py # content of test_second.py
from __future__ import annotations
def test_2(): def test_2():
pass pass

View File

@ -1,3 +1,6 @@
# content of test_third.py # content of test_third.py
from __future__ import annotations
def test_3(): def test_3():
pass pass

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,6 +1,8 @@
"""Module containing a parametrized tests testing cross-python serialization """Module containing a parametrized tests testing cross-python serialization
via the pickle module.""" via the pickle module."""
from __future__ import annotations
import shutil import shutil
import subprocess import subprocess
import textwrap import textwrap

View File

@ -1,4 +1,6 @@
# content of conftest.py # content of conftest.py
from __future__ import annotations
import pytest import pytest

View File

@ -1,5 +1,6 @@
# run this with $ pytest --collect-only test_collectonly.py # run this with $ pytest --collect-only test_collectonly.py
# #
from __future__ import annotations
def test_function(): def test_function():

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import pytest import pytest

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json import json
from pathlib import Path from pathlib import Path

View File

@ -6,15 +6,15 @@ keywords = [
"test", "test",
"unittest", "unittest",
] ]
license = {text = "MIT"} license = { text = "MIT" }
authors = [ authors = [
{name = "Holger Krekel"}, { name = "Holger Krekel" },
{name = "Bruno Oliveira"}, { name = "Bruno Oliveira" },
{name = "Ronny Pfannschmidt"}, { name = "Ronny Pfannschmidt" },
{name = "Floris Bruynooghe"}, { name = "Floris Bruynooghe" },
{name = "Brianna Laugher"}, { name = "Brianna Laugher" },
{name = "Florian Bruhin"}, { name = "Florian Bruhin" },
{name = "Others (See AUTHORS)"}, { name = "Others (See AUTHORS)" },
] ]
requires-python = ">=3.8" requires-python = ">=3.8"
classifiers = [ classifiers = [
@ -99,6 +99,7 @@ select = [
"E", # pycodestyle "E", # pycodestyle
"F", # pyflakes "F", # pyflakes
"I", # isort "I", # isort
"FA100", # add future annotations
"PYI", # flake8-pyi "PYI", # flake8-pyi
"UP", # pyupgrade "UP", # pyupgrade
"RUF", # ruff "RUF", # ruff
@ -169,7 +170,7 @@ lines-after-imports = 2
[tool.pylint.main] [tool.pylint.main]
# Maximum number of characters on a single line. # Maximum number of characters on a single line.
max-line-length = 120 max-line-length = 120
disable= [ disable = [
"abstract-method", "abstract-method",
"arguments-differ", "arguments-differ",
"arguments-renamed", "arguments-renamed",

View File

@ -9,6 +9,8 @@ our CHANGELOG) into Markdown (which is required by GitHub Releases).
Requires Python3.6+. Requires Python3.6+.
""" """
from __future__ import annotations
from pathlib import Path from pathlib import Path
import re import re
import sys import sys

View File

@ -14,6 +14,8 @@ After that, it will create a release using the `release` tox environment, and pu
`pytest bot <pytestbot@gmail.com>` commit author. `pytest bot <pytestbot@gmail.com>` commit author.
""" """
from __future__ import annotations
import argparse import argparse
from pathlib import Path from pathlib import Path
import re import re

View File

@ -1,6 +1,8 @@
# mypy: disallow-untyped-defs # mypy: disallow-untyped-defs
"""Invoke development tasks.""" """Invoke development tasks."""
from __future__ import annotations
import argparse import argparse
import os import os
from pathlib import Path from pathlib import Path

View File

@ -1,4 +1,6 @@
# mypy: disallow-untyped-defs # mypy: disallow-untyped-defs
from __future__ import annotations
import datetime import datetime
import pathlib import pathlib
import re import re

View File

@ -1,3 +1,6 @@
from __future__ import annotations
__all__ = ["__version__", "version_tuple"] __all__ = ["__version__", "version_tuple"]
try: try:

View File

@ -62,13 +62,13 @@ If things do not work right away:
global argcomplete script). global argcomplete script).
""" """
from __future__ import annotations
import argparse import argparse
from glob import glob from glob import glob
import os import os
import sys import sys
from typing import Any from typing import Any
from typing import List
from typing import Optional
class FastFilesCompleter: class FastFilesCompleter:
@ -77,7 +77,7 @@ class FastFilesCompleter:
def __init__(self, directories: bool = True) -> None: def __init__(self, directories: bool = True) -> None:
self.directories = directories self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> List[str]: def __call__(self, prefix: str, **kwargs: Any) -> list[str]:
# Only called on non option completions. # Only called on non option completions.
if os.sep in prefix[1:]: if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep) prefix_dir = len(os.path.dirname(prefix) + os.sep)
@ -104,7 +104,7 @@ if os.environ.get("_ARGCOMPLETE"):
import argcomplete.completers import argcomplete.completers
except ImportError: except ImportError:
sys.exit(-1) sys.exit(-1)
filescompleter: Optional[FastFilesCompleter] = FastFilesCompleter() filescompleter: FastFilesCompleter | None = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None: def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False) argcomplete.autocomplete(parser, always_complete_options=False)

View File

@ -1,5 +1,7 @@
"""Python inspection/code generation API.""" """Python inspection/code generation API."""
from __future__ import annotations
from .code import Code from .code import Code
from .code import ExceptionInfo from .code import ExceptionInfo
from .code import filter_traceback from .code import filter_traceback

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import ast import ast
import dataclasses import dataclasses
import inspect import inspect
@ -17,7 +19,6 @@ from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import ClassVar from typing import ClassVar
from typing import Dict
from typing import Final from typing import Final
from typing import final from typing import final
from typing import Generic from typing import Generic
@ -25,11 +26,9 @@ from typing import Iterable
from typing import List from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import Optional
from typing import overload from typing import overload
from typing import Pattern from typing import Pattern
from typing import Sequence from typing import Sequence
from typing import Set
from typing import SupportsIndex from typing import SupportsIndex
from typing import Tuple from typing import Tuple
from typing import Type from typing import Type
@ -57,6 +56,8 @@ if sys.version_info < (3, 11):
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"] _TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
EXCEPTION_OR_MORE = Union[Type[Exception], Tuple[Type[Exception], ...]]
class Code: class Code:
"""Wrapper around Python code objects.""" """Wrapper around Python code objects."""
@ -67,7 +68,7 @@ class Code:
self.raw = obj self.raw = obj
@classmethod @classmethod
def from_function(cls, obj: object) -> "Code": def from_function(cls, obj: object) -> Code:
return cls(getrawcode(obj)) return cls(getrawcode(obj))
def __eq__(self, other): def __eq__(self, other):
@ -85,7 +86,7 @@ class Code:
return self.raw.co_name return self.raw.co_name
@property @property
def path(self) -> Union[Path, str]: def path(self) -> Path | str:
"""Return a path object pointing to source code, or an ``str`` in """Return a path object pointing to source code, or an ``str`` in
case of ``OSError`` / non-existing file.""" case of ``OSError`` / non-existing file."""
if not self.raw.co_filename: if not self.raw.co_filename:
@ -102,17 +103,17 @@ class Code:
return self.raw.co_filename return self.raw.co_filename
@property @property
def fullsource(self) -> Optional["Source"]: def fullsource(self) -> Source | None:
"""Return a _pytest._code.Source object for the full source file of the code.""" """Return a _pytest._code.Source object for the full source file of the code."""
full, _ = findsource(self.raw) full, _ = findsource(self.raw)
return full return full
def source(self) -> "Source": def source(self) -> Source:
"""Return a _pytest._code.Source object for the code object's source only.""" """Return a _pytest._code.Source object for the code object's source only."""
# return source only for that part of code # return source only for that part of code
return Source(self.raw) return Source(self.raw)
def getargs(self, var: bool = False) -> Tuple[str, ...]: def getargs(self, var: bool = False) -> tuple[str, ...]:
"""Return a tuple with the argument names for the code object. """Return a tuple with the argument names for the code object.
If 'var' is set True also return the names of the variable and If 'var' is set True also return the names of the variable and
@ -141,11 +142,11 @@ class Frame:
return self.raw.f_lineno - 1 return self.raw.f_lineno - 1
@property @property
def f_globals(self) -> Dict[str, Any]: def f_globals(self) -> dict[str, Any]:
return self.raw.f_globals return self.raw.f_globals
@property @property
def f_locals(self) -> Dict[str, Any]: def f_locals(self) -> dict[str, Any]:
return self.raw.f_locals return self.raw.f_locals
@property @property
@ -153,7 +154,7 @@ class Frame:
return Code(self.raw.f_code) return Code(self.raw.f_code)
@property @property
def statement(self) -> "Source": def statement(self) -> Source:
"""Statement this frame is at.""" """Statement this frame is at."""
if self.code.fullsource is None: if self.code.fullsource is None:
return Source("") return Source("")
@ -197,14 +198,14 @@ class TracebackEntry:
def __init__( def __init__(
self, self,
rawentry: TracebackType, rawentry: TracebackType,
repr_style: Optional['Literal["short", "long"]'] = None, repr_style: Literal["short", "long"] | None = None,
) -> None: ) -> None:
self._rawentry: "Final" = rawentry self._rawentry: Final = rawentry
self._repr_style: "Final" = repr_style self._repr_style: Final = repr_style
def with_repr_style( def with_repr_style(
self, repr_style: Optional['Literal["short", "long"]'] self, repr_style: Literal["short", "long"] | None
) -> "TracebackEntry": ) -> TracebackEntry:
return TracebackEntry(self._rawentry, repr_style) return TracebackEntry(self._rawentry, repr_style)
@property @property
@ -223,19 +224,19 @@ class TracebackEntry:
return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1) return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
@property @property
def statement(self) -> "Source": def statement(self) -> Source:
"""_pytest._code.Source object for the current statement.""" """_pytest._code.Source object for the current statement."""
source = self.frame.code.fullsource source = self.frame.code.fullsource
assert source is not None assert source is not None
return source.getstatement(self.lineno) return source.getstatement(self.lineno)
@property @property
def path(self) -> Union[Path, str]: def path(self) -> Path | str:
"""Path to the source code.""" """Path to the source code."""
return self.frame.code.path return self.frame.code.path
@property @property
def locals(self) -> Dict[str, Any]: def locals(self) -> dict[str, Any]:
"""Locals of underlying frame.""" """Locals of underlying frame."""
return self.frame.f_locals return self.frame.f_locals
@ -243,8 +244,8 @@ class TracebackEntry:
return self.frame.code.firstlineno return self.frame.code.firstlineno
def getsource( def getsource(
self, astcache: Optional[Dict[Union[str, Path], ast.AST]] = None self, astcache: dict[str | Path, ast.AST] | None = None
) -> Optional["Source"]: ) -> Source | None:
"""Return failing source code.""" """Return failing source code."""
# we use the passed in astcache to not reparse asttrees # we use the passed in astcache to not reparse asttrees
# within exception info printing # within exception info printing
@ -270,7 +271,7 @@ class TracebackEntry:
source = property(getsource) source = property(getsource)
def ishidden(self, excinfo: Optional["ExceptionInfo[BaseException]"]) -> bool: def ishidden(self, excinfo: ExceptionInfo[BaseException] | None) -> bool:
"""Return True if the current frame has a var __tracebackhide__ """Return True if the current frame has a var __tracebackhide__
resolving to True. resolving to True.
@ -279,9 +280,7 @@ class TracebackEntry:
Mostly for internal use. Mostly for internal use.
""" """
tbh: Union[bool, Callable[[Optional[ExceptionInfo[BaseException]]], bool]] = ( tbh: bool | Callable[[ExceptionInfo[BaseException] | None], bool] = False
False
)
for maybe_ns_dct in (self.frame.f_locals, self.frame.f_globals): for maybe_ns_dct in (self.frame.f_locals, self.frame.f_globals):
# in normal cases, f_locals and f_globals are dictionaries # in normal cases, f_locals and f_globals are dictionaries
# however via `exec(...)` / `eval(...)` they can be other types # however via `exec(...)` / `eval(...)` they can be other types
@ -326,13 +325,13 @@ class Traceback(List[TracebackEntry]):
def __init__( def __init__(
self, self,
tb: Union[TracebackType, Iterable[TracebackEntry]], tb: TracebackType | Iterable[TracebackEntry],
) -> None: ) -> None:
"""Initialize from given python traceback object and ExceptionInfo.""" """Initialize from given python traceback object and ExceptionInfo."""
if isinstance(tb, TracebackType): if isinstance(tb, TracebackType):
def f(cur: TracebackType) -> Iterable[TracebackEntry]: def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_: Optional[TracebackType] = cur cur_: TracebackType | None = cur
while cur_ is not None: while cur_ is not None:
yield TracebackEntry(cur_) yield TracebackEntry(cur_)
cur_ = cur_.tb_next cur_ = cur_.tb_next
@ -343,11 +342,11 @@ class Traceback(List[TracebackEntry]):
def cut( def cut(
self, self,
path: Optional[Union["os.PathLike[str]", str]] = None, path: os.PathLike[str] | str | None = None,
lineno: Optional[int] = None, lineno: int | None = None,
firstlineno: Optional[int] = None, firstlineno: int | None = None,
excludepath: Optional["os.PathLike[str]"] = None, excludepath: os.PathLike[str] | None = None,
) -> "Traceback": ) -> Traceback:
"""Return a Traceback instance wrapping part of this Traceback. """Return a Traceback instance wrapping part of this Traceback.
By providing any combination of path, lineno and firstlineno, the By providing any combination of path, lineno and firstlineno, the
@ -378,14 +377,12 @@ class Traceback(List[TracebackEntry]):
return self return self
@overload @overload
def __getitem__(self, key: "SupportsIndex") -> TracebackEntry: ... def __getitem__(self, key: SupportsIndex) -> TracebackEntry: ...
@overload @overload
def __getitem__(self, key: slice) -> "Traceback": ... def __getitem__(self, key: slice) -> Traceback: ...
def __getitem__( def __getitem__(self, key: SupportsIndex | slice) -> TracebackEntry | Traceback:
self, key: Union["SupportsIndex", slice]
) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice): if isinstance(key, slice):
return self.__class__(super().__getitem__(key)) return self.__class__(super().__getitem__(key))
else: else:
@ -393,12 +390,9 @@ class Traceback(List[TracebackEntry]):
def filter( def filter(
self, self,
excinfo_or_fn: Union[ excinfo_or_fn: ExceptionInfo[BaseException] | Callable[[TracebackEntry], bool],
"ExceptionInfo[BaseException]",
Callable[[TracebackEntry], bool],
],
/, /,
) -> "Traceback": ) -> Traceback:
"""Return a Traceback instance with certain items removed. """Return a Traceback instance with certain items removed.
If the filter is an `ExceptionInfo`, removes all the ``TracebackEntry``s If the filter is an `ExceptionInfo`, removes all the ``TracebackEntry``s
@ -414,10 +408,10 @@ class Traceback(List[TracebackEntry]):
fn = excinfo_or_fn fn = excinfo_or_fn
return Traceback(filter(fn, self)) return Traceback(filter(fn, self))
def recursionindex(self) -> Optional[int]: def recursionindex(self) -> int | None:
"""Return the index of the frame/TracebackEntry where recursion originates if """Return the index of the frame/TracebackEntry where recursion originates if
appropriate, None if no recursion occurred.""" appropriate, None if no recursion occurred."""
cache: Dict[Tuple[Any, int, int], List[Dict[str, Any]]] = {} cache: dict[tuple[Any, int, int], list[dict[str, Any]]] = {}
for i, entry in enumerate(self): for i, entry in enumerate(self):
# id for the code.raw is needed to work around # id for the code.raw is needed to work around
# the strange metaprogramming in the decorator lib from pypi # the strange metaprogramming in the decorator lib from pypi
@ -445,15 +439,15 @@ class ExceptionInfo(Generic[E]):
_assert_start_repr: ClassVar = "AssertionError('assert " _assert_start_repr: ClassVar = "AssertionError('assert "
_excinfo: Optional[Tuple[Type["E"], "E", TracebackType]] _excinfo: tuple[type[E], E, TracebackType] | None
_striptext: str _striptext: str
_traceback: Optional[Traceback] _traceback: Traceback | None
def __init__( def __init__(
self, self,
excinfo: Optional[Tuple[Type["E"], "E", TracebackType]], excinfo: tuple[type[E], E, TracebackType] | None,
striptext: str = "", striptext: str = "",
traceback: Optional[Traceback] = None, traceback: Traceback | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -469,8 +463,8 @@ class ExceptionInfo(Generic[E]):
# This is OK to ignore because this class is (conceptually) readonly. # This is OK to ignore because this class is (conceptually) readonly.
# See https://github.com/python/mypy/issues/7049. # See https://github.com/python/mypy/issues/7049.
exception: E, # type: ignore[misc] exception: E, # type: ignore[misc]
exprinfo: Optional[str] = None, exprinfo: str | None = None,
) -> "ExceptionInfo[E]": ) -> ExceptionInfo[E]:
"""Return an ExceptionInfo for an existing exception. """Return an ExceptionInfo for an existing exception.
The exception must have a non-``None`` ``__traceback__`` attribute, The exception must have a non-``None`` ``__traceback__`` attribute,
@ -495,9 +489,9 @@ class ExceptionInfo(Generic[E]):
@classmethod @classmethod
def from_exc_info( def from_exc_info(
cls, cls,
exc_info: Tuple[Type[E], E, TracebackType], exc_info: tuple[type[E], E, TracebackType],
exprinfo: Optional[str] = None, exprinfo: str | None = None,
) -> "ExceptionInfo[E]": ) -> ExceptionInfo[E]:
"""Like :func:`from_exception`, but using old-style exc_info tuple.""" """Like :func:`from_exception`, but using old-style exc_info tuple."""
_striptext = "" _striptext = ""
if exprinfo is None and isinstance(exc_info[1], AssertionError): if exprinfo is None and isinstance(exc_info[1], AssertionError):
@ -510,9 +504,7 @@ class ExceptionInfo(Generic[E]):
return cls(exc_info, _striptext, _ispytest=True) return cls(exc_info, _striptext, _ispytest=True)
@classmethod @classmethod
def from_current( def from_current(cls, exprinfo: str | None = None) -> ExceptionInfo[BaseException]:
cls, exprinfo: Optional[str] = None
) -> "ExceptionInfo[BaseException]":
"""Return an ExceptionInfo matching the current traceback. """Return an ExceptionInfo matching the current traceback.
.. warning:: .. warning::
@ -532,17 +524,17 @@ class ExceptionInfo(Generic[E]):
return ExceptionInfo.from_exc_info(exc_info, exprinfo) return ExceptionInfo.from_exc_info(exc_info, exprinfo)
@classmethod @classmethod
def for_later(cls) -> "ExceptionInfo[E]": def for_later(cls) -> ExceptionInfo[E]:
"""Return an unfilled ExceptionInfo.""" """Return an unfilled ExceptionInfo."""
return cls(None, _ispytest=True) return cls(None, _ispytest=True)
def fill_unfilled(self, exc_info: Tuple[Type[E], E, TracebackType]) -> None: def fill_unfilled(self, exc_info: tuple[type[E], E, TracebackType]) -> None:
"""Fill an unfilled ExceptionInfo created with ``for_later()``.""" """Fill an unfilled ExceptionInfo created with ``for_later()``."""
assert self._excinfo is None, "ExceptionInfo was already filled" assert self._excinfo is None, "ExceptionInfo was already filled"
self._excinfo = exc_info self._excinfo = exc_info
@property @property
def type(self) -> Type[E]: def type(self) -> type[E]:
"""The exception class.""" """The exception class."""
assert ( assert (
self._excinfo is not None self._excinfo is not None
@ -605,16 +597,14 @@ class ExceptionInfo(Generic[E]):
text = text[len(self._striptext) :] text = text[len(self._striptext) :]
return text return text
def errisinstance( def errisinstance(self, exc: EXCEPTION_OR_MORE) -> bool:
self, exc: Union[Type[BaseException], Tuple[Type[BaseException], ...]]
) -> bool:
"""Return True if the exception is an instance of exc. """Return True if the exception is an instance of exc.
Consider using ``isinstance(excinfo.value, exc)`` instead. Consider using ``isinstance(excinfo.value, exc)`` instead.
""" """
return isinstance(self.value, exc) return isinstance(self.value, exc)
def _getreprcrash(self) -> Optional["ReprFileLocation"]: def _getreprcrash(self) -> ReprFileLocation | None:
# Find last non-hidden traceback entry that led to the exception of the # Find last non-hidden traceback entry that led to the exception of the
# traceback, or None if all hidden. # traceback, or None if all hidden.
for i in range(-1, -len(self.traceback) - 1, -1): for i in range(-1, -len(self.traceback) - 1, -1):
@ -630,13 +620,12 @@ class ExceptionInfo(Generic[E]):
showlocals: bool = False, showlocals: bool = False,
style: _TracebackStyle = "long", style: _TracebackStyle = "long",
abspath: bool = False, abspath: bool = False,
tbfilter: Union[ tbfilter: bool
bool, Callable[["ExceptionInfo[BaseException]"], Traceback] | Callable[[ExceptionInfo[BaseException]], _pytest._code.code.Traceback] = True,
] = True,
funcargs: bool = False, funcargs: bool = False,
truncate_locals: bool = True, truncate_locals: bool = True,
chain: bool = True, chain: bool = True,
) -> Union["ReprExceptionInfo", "ExceptionChainRepr"]: ) -> ReprExceptionInfo | ExceptionChainRepr:
"""Return str()able representation of this exception info. """Return str()able representation of this exception info.
:param bool showlocals: :param bool showlocals:
@ -714,7 +703,7 @@ class ExceptionInfo(Generic[E]):
] ]
) )
def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]": def match(self, regexp: str | Pattern[str]) -> Literal[True]:
"""Check whether the regular expression `regexp` matches the string """Check whether the regular expression `regexp` matches the string
representation of the exception using :func:`python:re.search`. representation of the exception using :func:`python:re.search`.
@ -732,9 +721,9 @@ class ExceptionInfo(Generic[E]):
def _group_contains( def _group_contains(
self, self,
exc_group: BaseExceptionGroup[BaseException], exc_group: BaseExceptionGroup[BaseException],
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], expected_exception: EXCEPTION_OR_MORE,
match: Union[str, Pattern[str], None], match: str | Pattern[str] | None,
target_depth: Optional[int] = None, target_depth: int | None = None,
current_depth: int = 1, current_depth: int = 1,
) -> bool: ) -> bool:
"""Return `True` if a `BaseExceptionGroup` contains a matching exception.""" """Return `True` if a `BaseExceptionGroup` contains a matching exception."""
@ -761,10 +750,10 @@ class ExceptionInfo(Generic[E]):
def group_contains( def group_contains(
self, self,
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], expected_exception: EXCEPTION_OR_MORE,
*, *,
match: Union[str, Pattern[str], None] = None, match: str | Pattern[str] | None = None,
depth: Optional[int] = None, depth: int | None = None,
) -> bool: ) -> bool:
"""Check whether a captured exception group contains a matching exception. """Check whether a captured exception group contains a matching exception.
@ -806,15 +795,15 @@ class FormattedExcinfo:
showlocals: bool = False showlocals: bool = False
style: _TracebackStyle = "long" style: _TracebackStyle = "long"
abspath: bool = True abspath: bool = True
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]] = True tbfilter: bool | Callable[[ExceptionInfo[BaseException]], Traceback] = True
funcargs: bool = False funcargs: bool = False
truncate_locals: bool = True truncate_locals: bool = True
chain: bool = True chain: bool = True
astcache: Dict[Union[str, Path], ast.AST] = dataclasses.field( astcache: dict[str | Path, ast.AST] = dataclasses.field(
default_factory=dict, init=False, repr=False default_factory=dict, init=False, repr=False
) )
def _getindent(self, source: "Source") -> int: def _getindent(self, source: Source) -> int:
# Figure out indent for the given source. # Figure out indent for the given source.
try: try:
s = str(source.getstatement(len(source) - 1)) s = str(source.getstatement(len(source) - 1))
@ -829,13 +818,13 @@ class FormattedExcinfo:
return 0 return 0
return 4 + (len(s) - len(s.lstrip())) return 4 + (len(s) - len(s.lstrip()))
def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]: def _getentrysource(self, entry: TracebackEntry) -> Source | None:
source = entry.getsource(self.astcache) source = entry.getsource(self.astcache)
if source is not None: if source is not None:
source = source.deindent() source = source.deindent()
return source return source
def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]: def repr_args(self, entry: TracebackEntry) -> ReprFuncArgs | None:
if self.funcargs: if self.funcargs:
args = [] args = []
for argname, argvalue in entry.frame.getargs(var=True): for argname, argvalue in entry.frame.getargs(var=True):
@ -845,11 +834,11 @@ class FormattedExcinfo:
def get_source( def get_source(
self, self,
source: Optional["Source"], source: Source | None,
line_index: int = -1, line_index: int = -1,
excinfo: Optional[ExceptionInfo[BaseException]] = None, excinfo: ExceptionInfo[BaseException] | None = None,
short: bool = False, short: bool = False,
) -> List[str]: ) -> list[str]:
"""Return formatted and marked up source lines.""" """Return formatted and marked up source lines."""
lines = [] lines = []
if source is not None and line_index < 0: if source is not None and line_index < 0:
@ -878,7 +867,7 @@ class FormattedExcinfo:
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
indent: int = 4, indent: int = 4,
markall: bool = False, markall: bool = False,
) -> List[str]: ) -> list[str]:
lines = [] lines = []
indentstr = " " * indent indentstr = " " * indent
# Get the real exception information out. # Get the real exception information out.
@ -890,7 +879,7 @@ class FormattedExcinfo:
failindent = indentstr failindent = indentstr
return lines return lines
def repr_locals(self, locals: Mapping[str, object]) -> Optional["ReprLocals"]: def repr_locals(self, locals: Mapping[str, object]) -> ReprLocals | None:
if self.showlocals: if self.showlocals:
lines = [] lines = []
keys = [loc for loc in locals if loc[0] != "@"] keys = [loc for loc in locals if loc[0] != "@"]
@ -918,10 +907,10 @@ class FormattedExcinfo:
def repr_traceback_entry( def repr_traceback_entry(
self, self,
entry: Optional[TracebackEntry], entry: TracebackEntry | None,
excinfo: Optional[ExceptionInfo[BaseException]] = None, excinfo: ExceptionInfo[BaseException] | None = None,
) -> "ReprEntry": ) -> ReprEntry:
lines: List[str] = [] lines: list[str] = []
style = ( style = (
entry._repr_style entry._repr_style
if entry is not None and entry._repr_style is not None if entry is not None and entry._repr_style is not None
@ -956,7 +945,7 @@ class FormattedExcinfo:
lines.extend(self.get_exconly(excinfo, indent=4)) lines.extend(self.get_exconly(excinfo, indent=4))
return ReprEntry(lines, None, None, None, style) return ReprEntry(lines, None, None, None, style)
def _makepath(self, path: Union[Path, str]) -> str: def _makepath(self, path: Path | str) -> str:
if not self.abspath and isinstance(path, Path): if not self.abspath and isinstance(path, Path):
try: try:
np = bestrelpath(Path.cwd(), path) np = bestrelpath(Path.cwd(), path)
@ -966,7 +955,7 @@ class FormattedExcinfo:
return np return np
return str(path) return str(path)
def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback": def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> ReprTraceback:
traceback = excinfo.traceback traceback = excinfo.traceback
if callable(self.tbfilter): if callable(self.tbfilter):
traceback = self.tbfilter(excinfo) traceback = self.tbfilter(excinfo)
@ -997,7 +986,7 @@ class FormattedExcinfo:
def _truncate_recursive_traceback( def _truncate_recursive_traceback(
self, traceback: Traceback self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]: ) -> tuple[Traceback, str | None]:
"""Truncate the given recursive traceback trying to find the starting """Truncate the given recursive traceback trying to find the starting
point of the recursion. point of the recursion.
@ -1014,7 +1003,7 @@ class FormattedExcinfo:
recursionindex = traceback.recursionindex() recursionindex = traceback.recursionindex()
except Exception as e: except Exception as e:
max_frames = 10 max_frames = 10
extraline: Optional[str] = ( extraline: str | None = (
"!!! Recursion error detected, but an error occurred locating the origin of recursion.\n" "!!! Recursion error detected, but an error occurred locating the origin of recursion.\n"
" The following exception happened when comparing locals in the stack frame:\n" " The following exception happened when comparing locals in the stack frame:\n"
f" {type(e).__name__}: {e!s}\n" f" {type(e).__name__}: {e!s}\n"
@ -1032,16 +1021,12 @@ class FormattedExcinfo:
return traceback, extraline return traceback, extraline
def repr_excinfo( def repr_excinfo(self, excinfo: ExceptionInfo[BaseException]) -> ExceptionChainRepr:
self, excinfo: ExceptionInfo[BaseException] repr_chain: list[tuple[ReprTraceback, ReprFileLocation | None, str | None]] = []
) -> "ExceptionChainRepr": e: BaseException | None = excinfo.value
repr_chain: List[ excinfo_: ExceptionInfo[BaseException] | None = excinfo
Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]
] = []
e: Optional[BaseException] = excinfo.value
excinfo_: Optional[ExceptionInfo[BaseException]] = excinfo
descr = None descr = None
seen: Set[int] = set() seen: set[int] = set()
while e is not None and id(e) not in seen: while e is not None and id(e) not in seen:
seen.add(id(e)) seen.add(id(e))
@ -1050,7 +1035,7 @@ class FormattedExcinfo:
# full support for exception groups added to ExceptionInfo. # full support for exception groups added to ExceptionInfo.
# See https://github.com/pytest-dev/pytest/issues/9159 # See https://github.com/pytest-dev/pytest/issues/9159
if isinstance(e, BaseExceptionGroup): if isinstance(e, BaseExceptionGroup):
reprtraceback: Union[ReprTracebackNative, ReprTraceback] = ( reprtraceback: ReprTracebackNative | ReprTraceback = (
ReprTracebackNative( ReprTracebackNative(
traceback.format_exception( traceback.format_exception(
type(excinfo_.value), type(excinfo_.value),
@ -1108,9 +1093,9 @@ class TerminalRepr:
@dataclasses.dataclass(eq=False) @dataclasses.dataclass(eq=False)
class ExceptionRepr(TerminalRepr): class ExceptionRepr(TerminalRepr):
# Provided by subclasses. # Provided by subclasses.
reprtraceback: "ReprTraceback" reprtraceback: ReprTraceback
reprcrash: Optional["ReprFileLocation"] reprcrash: ReprFileLocation | None
sections: List[Tuple[str, str, str]] = dataclasses.field( sections: list[tuple[str, str, str]] = dataclasses.field(
init=False, default_factory=list init=False, default_factory=list
) )
@ -1125,13 +1110,11 @@ class ExceptionRepr(TerminalRepr):
@dataclasses.dataclass(eq=False) @dataclasses.dataclass(eq=False)
class ExceptionChainRepr(ExceptionRepr): class ExceptionChainRepr(ExceptionRepr):
chain: Sequence[Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]] chain: Sequence[tuple[ReprTraceback, ReprFileLocation | None, str | None]]
def __init__( def __init__(
self, self,
chain: Sequence[ chain: Sequence[tuple[ReprTraceback, ReprFileLocation | None, str | None]],
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None: ) -> None:
# reprcrash and reprtraceback of the outermost (the newest) exception # reprcrash and reprtraceback of the outermost (the newest) exception
# in the chain. # in the chain.
@ -1152,8 +1135,8 @@ class ExceptionChainRepr(ExceptionRepr):
@dataclasses.dataclass(eq=False) @dataclasses.dataclass(eq=False)
class ReprExceptionInfo(ExceptionRepr): class ReprExceptionInfo(ExceptionRepr):
reprtraceback: "ReprTraceback" reprtraceback: ReprTraceback
reprcrash: Optional["ReprFileLocation"] reprcrash: ReprFileLocation | None
def toterminal(self, tw: TerminalWriter) -> None: def toterminal(self, tw: TerminalWriter) -> None:
self.reprtraceback.toterminal(tw) self.reprtraceback.toterminal(tw)
@ -1162,8 +1145,8 @@ class ReprExceptionInfo(ExceptionRepr):
@dataclasses.dataclass(eq=False) @dataclasses.dataclass(eq=False)
class ReprTraceback(TerminalRepr): class ReprTraceback(TerminalRepr):
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]] reprentries: Sequence[ReprEntry | ReprEntryNative]
extraline: Optional[str] extraline: str | None
style: _TracebackStyle style: _TracebackStyle
entrysep: ClassVar = "_ " entrysep: ClassVar = "_ "
@ -1207,9 +1190,9 @@ class ReprEntryNative(TerminalRepr):
@dataclasses.dataclass(eq=False) @dataclasses.dataclass(eq=False)
class ReprEntry(TerminalRepr): class ReprEntry(TerminalRepr):
lines: Sequence[str] lines: Sequence[str]
reprfuncargs: Optional["ReprFuncArgs"] reprfuncargs: ReprFuncArgs | None
reprlocals: Optional["ReprLocals"] reprlocals: ReprLocals | None
reprfileloc: Optional["ReprFileLocation"] reprfileloc: ReprFileLocation | None
style: _TracebackStyle style: _TracebackStyle
def _write_entry_lines(self, tw: TerminalWriter) -> None: def _write_entry_lines(self, tw: TerminalWriter) -> None:
@ -1233,9 +1216,9 @@ class ReprEntry(TerminalRepr):
# such as "> assert 0" # such as "> assert 0"
fail_marker = f"{FormattedExcinfo.fail_marker} " fail_marker = f"{FormattedExcinfo.fail_marker} "
indent_size = len(fail_marker) indent_size = len(fail_marker)
indents: List[str] = [] indents: list[str] = []
source_lines: List[str] = [] source_lines: list[str] = []
failure_lines: List[str] = [] failure_lines: list[str] = []
for index, line in enumerate(self.lines): for index, line in enumerate(self.lines):
is_failure_line = line.startswith(fail_marker) is_failure_line = line.startswith(fail_marker)
if is_failure_line: if is_failure_line:
@ -1314,7 +1297,7 @@ class ReprLocals(TerminalRepr):
@dataclasses.dataclass(eq=False) @dataclasses.dataclass(eq=False)
class ReprFuncArgs(TerminalRepr): class ReprFuncArgs(TerminalRepr):
args: Sequence[Tuple[str, object]] args: Sequence[tuple[str, object]]
def toterminal(self, tw: TerminalWriter) -> None: def toterminal(self, tw: TerminalWriter) -> None:
if self.args: if self.args:
@ -1335,7 +1318,7 @@ class ReprFuncArgs(TerminalRepr):
tw.line("") tw.line("")
def getfslineno(obj: object) -> Tuple[Union[str, Path], int]: def getfslineno(obj: object) -> tuple[str | Path, int]:
"""Return source location (path, lineno) for the given object. """Return source location (path, lineno) for the given object.
If the source cannot be determined return ("", -1). If the source cannot be determined return ("", -1).

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import ast import ast
from bisect import bisect_right from bisect import bisect_right
import inspect import inspect
@ -7,11 +9,7 @@ import tokenize
import types import types
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Optional
from typing import overload from typing import overload
from typing import Tuple
from typing import Union
import warnings import warnings
@ -23,7 +21,7 @@ class Source:
def __init__(self, obj: object = None) -> None: def __init__(self, obj: object = None) -> None:
if not obj: if not obj:
self.lines: List[str] = [] self.lines: list[str] = []
elif isinstance(obj, Source): elif isinstance(obj, Source):
self.lines = obj.lines self.lines = obj.lines
elif isinstance(obj, (tuple, list)): elif isinstance(obj, (tuple, list)):
@ -50,9 +48,9 @@ class Source:
def __getitem__(self, key: int) -> str: ... def __getitem__(self, key: int) -> str: ...
@overload @overload
def __getitem__(self, key: slice) -> "Source": ... def __getitem__(self, key: slice) -> Source: ...
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: def __getitem__(self, key: int | slice) -> str | Source:
if isinstance(key, int): if isinstance(key, int):
return self.lines[key] return self.lines[key]
else: else:
@ -68,7 +66,7 @@ class Source:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.lines) return len(self.lines)
def strip(self) -> "Source": def strip(self) -> Source:
"""Return new Source object with trailing and leading blank lines removed.""" """Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self) start, end = 0, len(self)
while start < end and not self.lines[start].strip(): while start < end and not self.lines[start].strip():
@ -79,20 +77,20 @@ class Source:
source.lines[:] = self.lines[start:end] source.lines[:] = self.lines[start:end]
return source return source
def indent(self, indent: str = " " * 4) -> "Source": def indent(self, indent: str = " " * 4) -> Source:
"""Return a copy of the source object with all lines indented by the """Return a copy of the source object with all lines indented by the
given indent-string.""" given indent-string."""
newsource = Source() newsource = Source()
newsource.lines = [(indent + line) for line in self.lines] newsource.lines = [(indent + line) for line in self.lines]
return newsource return newsource
def getstatement(self, lineno: int) -> "Source": def getstatement(self, lineno: int) -> Source:
"""Return Source statement which contains the given linenumber """Return Source statement which contains the given linenumber
(counted from 0).""" (counted from 0)."""
start, end = self.getstatementrange(lineno) start, end = self.getstatementrange(lineno)
return self[start:end] return self[start:end]
def getstatementrange(self, lineno: int) -> Tuple[int, int]: def getstatementrange(self, lineno: int) -> tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region """Return (start, end) tuple which spans the minimal statement region
which containing the given lineno.""" which containing the given lineno."""
if not (0 <= lineno < len(self)): if not (0 <= lineno < len(self)):
@ -100,7 +98,7 @@ class Source:
ast, start, end = getstatementrange_ast(lineno, self) ast, start, end = getstatementrange_ast(lineno, self)
return start, end return start, end
def deindent(self) -> "Source": def deindent(self) -> Source:
"""Return a new Source object deindented.""" """Return a new Source object deindented."""
newsource = Source() newsource = Source()
newsource.lines[:] = deindent(self.lines) newsource.lines[:] = deindent(self.lines)
@ -115,7 +113,7 @@ class Source:
# #
def findsource(obj) -> Tuple[Optional[Source], int]: def findsource(obj) -> tuple[Source | None, int]:
try: try:
sourcelines, lineno = inspect.findsource(obj) sourcelines, lineno = inspect.findsource(obj)
except Exception: except Exception:
@ -138,14 +136,14 @@ def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
raise TypeError(f"could not get code object for {obj!r}") raise TypeError(f"could not get code object for {obj!r}")
def deindent(lines: Iterable[str]) -> List[str]: def deindent(lines: Iterable[str]) -> list[str]:
return textwrap.dedent("\n".join(lines)).splitlines() return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
# Flatten all statements and except handlers into one lineno-list. # Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1. # AST's line numbers start indexing at 1.
values: List[int] = [] values: list[int] = []
for x in ast.walk(node): for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)): if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# The lineno points to the class/def, so need to include the decorators. # The lineno points to the class/def, so need to include the decorators.
@ -154,7 +152,7 @@ def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[i
values.append(d.lineno - 1) values.append(d.lineno - 1)
values.append(x.lineno - 1) values.append(x.lineno - 1)
for name in ("finalbody", "orelse"): for name in ("finalbody", "orelse"):
val: Optional[List[ast.stmt]] = getattr(x, name, None) val: list[ast.stmt] | None = getattr(x, name, None)
if val: if val:
# Treat the finally/orelse part as its own statement. # Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1) values.append(val[0].lineno - 1 - 1)
@ -172,8 +170,8 @@ def getstatementrange_ast(
lineno: int, lineno: int,
source: Source, source: Source,
assertion: bool = False, assertion: bool = False,
astnode: Optional[ast.AST] = None, astnode: ast.AST | None = None,
) -> Tuple[ast.AST, int, int]: ) -> tuple[ast.AST, int, int]:
if astnode is None: if astnode is None:
content = str(source) content = str(source)
# See #4260: # See #4260:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from .terminalwriter import get_terminal_width from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter from .terminalwriter import TerminalWriter

View File

@ -13,6 +13,8 @@
# tuples with fairly non-descriptive content. This is modeled very much # tuples with fairly non-descriptive content. This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists. If you find it # after Lisp/Scheme - style pretty-printing of lists. If you find it
# useful, thank small children who sleep at night. # useful, thank small children who sleep at night.
from __future__ import annotations
import collections as _collections import collections as _collections
import dataclasses as _dataclasses import dataclasses as _dataclasses
from io import StringIO as _StringIO from io import StringIO as _StringIO
@ -20,13 +22,8 @@ import re
import types as _types import types as _types
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict
from typing import IO from typing import IO
from typing import Iterator from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
class _safe_key: class _safe_key:
@ -64,7 +61,7 @@ class PrettyPrinter:
self, self,
indent: int = 4, indent: int = 4,
width: int = 80, width: int = 80,
depth: Optional[int] = None, depth: int | None = None,
) -> None: ) -> None:
"""Handle pretty printing operations onto a stream using a set of """Handle pretty printing operations onto a stream using a set of
configured parameters. configured parameters.
@ -100,7 +97,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
objid = id(object) objid = id(object)
@ -136,7 +133,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
cls_name = object.__class__.__name__ cls_name = object.__class__.__name__
@ -149,9 +146,9 @@ class PrettyPrinter:
self._format_namespace_items(items, stream, indent, allowance, context, level) self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")") stream.write(")")
_dispatch: Dict[ _dispatch: dict[
Callable[..., str], Callable[..., str],
Callable[["PrettyPrinter", Any, IO[str], int, int, Set[int], int], None], Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
] = {} ] = {}
def _pprint_dict( def _pprint_dict(
@ -160,7 +157,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
@ -177,7 +174,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not len(object): if not len(object):
@ -196,7 +193,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write("[") stream.write("[")
@ -211,7 +208,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write("(") stream.write("(")
@ -226,7 +223,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not len(object): if not len(object):
@ -252,7 +249,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
@ -311,7 +308,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
@ -340,7 +337,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
write = stream.write write = stream.write
@ -358,7 +355,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write("mappingproxy(") stream.write("mappingproxy(")
@ -373,7 +370,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if type(object) is _types.SimpleNamespace: if type(object) is _types.SimpleNamespace:
@ -391,11 +388,11 @@ class PrettyPrinter:
def _format_dict_items( def _format_dict_items(
self, self,
items: List[Tuple[Any, Any]], items: list[tuple[Any, Any]],
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not items: if not items:
@ -415,11 +412,11 @@ class PrettyPrinter:
def _format_namespace_items( def _format_namespace_items(
self, self,
items: List[Tuple[Any, Any]], items: list[tuple[Any, Any]],
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not items: if not items:
@ -452,11 +449,11 @@ class PrettyPrinter:
def _format_items( def _format_items(
self, self,
items: List[Any], items: list[Any],
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not items: if not items:
@ -473,7 +470,7 @@ class PrettyPrinter:
write("\n" + " " * indent) write("\n" + " " * indent)
def _repr(self, object: Any, context: Set[int], level: int) -> str: def _repr(self, object: Any, context: set[int], level: int) -> str:
return self._safe_repr(object, context.copy(), self._depth, level) return self._safe_repr(object, context.copy(), self._depth, level)
def _pprint_default_dict( def _pprint_default_dict(
@ -482,7 +479,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
rdf = self._repr(object.default_factory, context, level) rdf = self._repr(object.default_factory, context, level)
@ -498,7 +495,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + "(")
@ -519,7 +516,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])): if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
@ -538,7 +535,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + "(")
@ -557,7 +554,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
@ -570,7 +567,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
@ -583,7 +580,7 @@ class PrettyPrinter:
stream: IO[str], stream: IO[str],
indent: int, indent: int,
allowance: int, allowance: int,
context: Set[int], context: set[int],
level: int, level: int,
) -> None: ) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
@ -591,7 +588,7 @@ class PrettyPrinter:
_dispatch[_collections.UserString.__repr__] = _pprint_user_string _dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr( def _safe_repr(
self, object: Any, context: Set[int], maxlevels: Optional[int], level: int self, object: Any, context: set[int], maxlevels: int | None, level: int
) -> str: ) -> str:
typ = type(object) typ = type(object)
if typ in _builtin_scalars: if typ in _builtin_scalars:
@ -608,7 +605,7 @@ class PrettyPrinter:
if objid in context: if objid in context:
return _recursion(object) return _recursion(object)
context.add(objid) context.add(objid)
components: List[str] = [] components: list[str] = []
append = components.append append = components.append
level += 1 level += 1
for k, v in sorted(object.items(), key=_safe_tuple): for k, v in sorted(object.items(), key=_safe_tuple):

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import pprint import pprint
import reprlib import reprlib
from typing import Optional
def _try_repr_or_str(obj: object) -> str: def _try_repr_or_str(obj: object) -> str:
@ -38,7 +39,7 @@ class SafeRepr(reprlib.Repr):
information on exceptions raised during the call. information on exceptions raised during the call.
""" """
def __init__(self, maxsize: Optional[int], use_ascii: bool = False) -> None: def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
""" """
:param maxsize: :param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis If not None, will truncate the resulting repr to that specific size, using ellipsis
@ -97,7 +98,7 @@ DEFAULT_REPR_MAX_SIZE = 240
def saferepr( def saferepr(
obj: object, maxsize: Optional[int] = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
) -> str: ) -> str:
"""Return a size-limited safe repr-string for the given object. """Return a size-limited safe repr-string for the given object.

View File

@ -1,11 +1,12 @@
"""Helper functions for writing to terminals and files.""" """Helper functions for writing to terminals and files."""
from __future__ import annotations
import os import os
import shutil import shutil
import sys import sys
from typing import final from typing import final
from typing import Literal from typing import Literal
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import TextIO from typing import TextIO
@ -65,7 +66,7 @@ class TerminalWriter:
invert=7, invert=7,
) )
def __init__(self, file: Optional[TextIO] = None) -> None: def __init__(self, file: TextIO | None = None) -> None:
if file is None: if file is None:
file = sys.stdout file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32": if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
@ -79,7 +80,7 @@ class TerminalWriter:
self._file = file self._file = file
self.hasmarkup = should_do_markup(file) self.hasmarkup = should_do_markup(file)
self._current_line = "" self._current_line = ""
self._terminal_width: Optional[int] = None self._terminal_width: int | None = None
self.code_highlight = True self.code_highlight = True
@property @property
@ -110,8 +111,8 @@ class TerminalWriter:
def sep( def sep(
self, self,
sepchar: str, sepchar: str,
title: Optional[str] = None, title: str | None = None,
fullwidth: Optional[int] = None, fullwidth: int | None = None,
**markup: bool, **markup: bool,
) -> None: ) -> None:
if fullwidth is None: if fullwidth is None:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from functools import lru_cache from functools import lru_cache
import unicodedata import unicodedata

View File

@ -1,11 +1,11 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions.""" """Support for presenting detailed information in failing assertions."""
from __future__ import annotations
import sys import sys
from typing import Any from typing import Any
from typing import Generator from typing import Generator
from typing import List
from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from _pytest.assertion import rewrite from _pytest.assertion import rewrite
@ -94,7 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None: def __init__(self, config: Config, mode) -> None:
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.hook: Optional[rewrite.AssertionRewritingHook] = None self.hook: rewrite.AssertionRewritingHook | None = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
@ -113,7 +113,7 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
return hook return hook
def pytest_collection(session: "Session") -> None: def pytest_collection(session: Session) -> None:
# This hook is only called when test modules are collected # This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist # so for example not in the managing process of pytest-xdist
# (which does not collect test modules). # (which does not collect test modules).
@ -133,7 +133,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
""" """
ihook = item.ihook ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> Optional[str]: def callbinrepr(op, left: object, right: object) -> str | None:
"""Call the pytest_assertrepr_compare hook and prepare the result. """Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the This uses the first result from the hook and then ensures the
@ -179,7 +179,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
util._config = None util._config = None
def pytest_sessionfinish(session: "Session") -> None: def pytest_sessionfinish(session: Session) -> None:
assertstate = session.config.stash.get(assertstate_key, None) assertstate = session.config.stash.get(assertstate_key, None)
if assertstate: if assertstate:
if assertstate.hook is not None: if assertstate.hook is not None:
@ -188,5 +188,5 @@ def pytest_sessionfinish(session: "Session") -> None:
def pytest_assertrepr_compare( def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any config: Config, op: str, left: Any, right: Any
) -> Optional[List[str]]: ) -> list[str] | None:
return util.assertrepr_compare(config=config, op=op, left=left, right=right) return util.assertrepr_compare(config=config, op=op, left=left, right=right)

View File

@ -1,5 +1,7 @@
"""Rewrite assertion AST to produce nice error messages.""" """Rewrite assertion AST to produce nice error messages."""
from __future__ import annotations
import ast import ast
from collections import defaultdict from collections import defaultdict
import errno import errno
@ -18,17 +20,11 @@ import sys
import tokenize import tokenize
import types import types
from typing import Callable from typing import Callable
from typing import Dict
from typing import IO from typing import IO
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
@ -73,17 +69,17 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self.fnpats = config.getini("python_files") self.fnpats = config.getini("python_files")
except ValueError: except ValueError:
self.fnpats = ["test_*.py", "*_test.py"] self.fnpats = ["test_*.py", "*_test.py"]
self.session: Optional[Session] = None self.session: Session | None = None
self._rewritten_names: Dict[str, Path] = {} self._rewritten_names: dict[str, Path] = {}
self._must_rewrite: Set[str] = set() self._must_rewrite: set[str] = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506) # which might result in infinite recursion (#3506)
self._writing_pyc = False self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"} self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache: Dict[str, bool] = {} self._marked_for_rewrite_cache: dict[str, bool] = {}
self._session_paths_checked = False self._session_paths_checked = False
def set_session(self, session: Optional[Session]) -> None: def set_session(self, session: Session | None) -> None:
self.session = session self.session = session
self._session_paths_checked = False self._session_paths_checked = False
@ -93,9 +89,9 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def find_spec( def find_spec(
self, self,
name: str, name: str,
path: Optional[Sequence[Union[str, bytes]]] = None, path: Sequence[str | bytes] | None = None,
target: Optional[types.ModuleType] = None, target: types.ModuleType | None = None,
) -> Optional[importlib.machinery.ModuleSpec]: ) -> importlib.machinery.ModuleSpec | None:
if self._writing_pyc: if self._writing_pyc:
return None return None
state = self.config.stash[assertstate_key] state = self.config.stash[assertstate_key]
@ -132,7 +128,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def create_module( def create_module(
self, spec: importlib.machinery.ModuleSpec self, spec: importlib.machinery.ModuleSpec
) -> Optional[types.ModuleType]: ) -> types.ModuleType | None:
return None # default behaviour is fine return None # default behaviour is fine
def exec_module(self, module: types.ModuleType) -> None: def exec_module(self, module: types.ModuleType) -> None:
@ -177,7 +173,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace(f"found cached rewritten pyc for {fn}") state.trace(f"found cached rewritten pyc for {fn}")
exec(co, module.__dict__) exec(co, module.__dict__)
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool: def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
"""A fast way to get out of rewriting modules. """A fast way to get out of rewriting modules.
Profiling has shown that the call to PathFinder.find_spec (inside of Profiling has shown that the call to PathFinder.find_spec (inside of
@ -216,7 +212,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace(f"early skip of rewriting module: {name}") state.trace(f"early skip of rewriting module: {name}")
return True return True
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool: def _should_rewrite(self, name: str, fn: str, state: AssertionState) -> bool:
# always rewrite conftest files # always rewrite conftest files
if os.path.basename(fn) == "conftest.py": if os.path.basename(fn) == "conftest.py":
state.trace(f"rewriting conftest file: {fn!r}") state.trace(f"rewriting conftest file: {fn!r}")
@ -237,7 +233,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return self._is_marked_for_rewrite(name, state) return self._is_marked_for_rewrite(name, state)
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool: def _is_marked_for_rewrite(self, name: str, state: AssertionState) -> bool:
try: try:
return self._marked_for_rewrite_cache[name] return self._marked_for_rewrite_cache[name]
except KeyError: except KeyError:
@ -278,7 +274,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
stacklevel=5, stacklevel=5,
) )
def get_data(self, pathname: Union[str, bytes]) -> bytes: def get_data(self, pathname: str | bytes) -> bytes:
"""Optional PEP302 get_data API.""" """Optional PEP302 get_data API."""
with open(pathname, "rb") as f: with open(pathname, "rb") as f:
return f.read() return f.read()
@ -317,7 +313,7 @@ def _write_pyc_fp(
def _write_pyc( def _write_pyc(
state: "AssertionState", state: AssertionState,
co: types.CodeType, co: types.CodeType,
source_stat: os.stat_result, source_stat: os.stat_result,
pyc: Path, pyc: Path,
@ -341,7 +337,7 @@ def _write_pyc(
return True return True
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeType]:
"""Read and rewrite *fn* and return the code object.""" """Read and rewrite *fn* and return the code object."""
stat = os.stat(fn) stat = os.stat(fn)
source = fn.read_bytes() source = fn.read_bytes()
@ -354,7 +350,7 @@ def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeT
def _read_pyc( def _read_pyc(
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]: ) -> types.CodeType | None:
"""Possibly read a pytest pyc containing rewritten code. """Possibly read a pytest pyc containing rewritten code.
Return rewritten code if successful or None if not. Return rewritten code if successful or None if not.
@ -404,8 +400,8 @@ def _read_pyc(
def rewrite_asserts( def rewrite_asserts(
mod: ast.Module, mod: ast.Module,
source: bytes, source: bytes,
module_path: Optional[str] = None, module_path: str | None = None,
config: Optional[Config] = None, config: Config | None = None,
) -> None: ) -> None:
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config, source).run(mod) AssertionRewriter(module_path, config, source).run(mod)
@ -425,7 +421,7 @@ def _saferepr(obj: object) -> str:
return saferepr(obj, maxsize=maxsize).replace("\n", "\\n") return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]: def _get_maxsize_for_saferepr(config: Config | None) -> int | None:
"""Get `maxsize` configuration for saferepr based on the given config object.""" """Get `maxsize` configuration for saferepr based on the given config object."""
if config is None: if config is None:
verbosity = 0 verbosity = 0
@ -543,14 +539,14 @@ def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
def _get_assertion_exprs(src: bytes) -> Dict[int, str]: def _get_assertion_exprs(src: bytes) -> dict[int, str]:
"""Return a mapping from {lineno: "assertion test expression"}.""" """Return a mapping from {lineno: "assertion test expression"}."""
ret: Dict[int, str] = {} ret: dict[int, str] = {}
depth = 0 depth = 0
lines: List[str] = [] lines: list[str] = []
assert_lineno: Optional[int] = None assert_lineno: int | None = None
seen_lines: Set[int] = set() seen_lines: set[int] = set()
def _write_and_reset() -> None: def _write_and_reset() -> None:
nonlocal depth, lines, assert_lineno, seen_lines nonlocal depth, lines, assert_lineno, seen_lines
@ -657,7 +653,7 @@ class AssertionRewriter(ast.NodeVisitor):
""" """
def __init__( def __init__(
self, module_path: Optional[str], config: Optional[Config], source: bytes self, module_path: str | None, config: Config | None, source: bytes
) -> None: ) -> None:
super().__init__() super().__init__()
self.module_path = module_path self.module_path = module_path
@ -670,7 +666,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.enable_assertion_pass_hook = False self.enable_assertion_pass_hook = False
self.source = source self.source = source
self.scope: tuple[ast.AST, ...] = () self.scope: tuple[ast.AST, ...] = ()
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], Dict[str, str]] = ( self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = (
defaultdict(dict) defaultdict(dict)
) )
@ -737,7 +733,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Collect asserts. # Collect asserts.
self.scope = (mod,) self.scope = (mod,)
nodes: List[Union[ast.AST, Sentinel]] = [mod] nodes: list[ast.AST | Sentinel] = [mod]
while nodes: while nodes:
node = nodes.pop() node = nodes.pop()
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
@ -749,7 +745,7 @@ class AssertionRewriter(ast.NodeVisitor):
assert isinstance(node, ast.AST) assert isinstance(node, ast.AST)
for name, field in ast.iter_fields(node): for name, field in ast.iter_fields(node):
if isinstance(field, list): if isinstance(field, list):
new: List[ast.AST] = [] new: list[ast.AST] = []
for i, child in enumerate(field): for i, child in enumerate(field):
if isinstance(child, ast.Assert): if isinstance(child, ast.Assert):
# Transform assert. # Transform assert.
@ -821,7 +817,7 @@ class AssertionRewriter(ast.NodeVisitor):
to format a string of %-formatted values as added by to format a string of %-formatted values as added by
.explanation_param(). .explanation_param().
""" """
self.explanation_specifiers: Dict[str, ast.expr] = {} self.explanation_specifiers: dict[str, ast.expr] = {}
self.stack.append(self.explanation_specifiers) self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
@ -835,7 +831,7 @@ class AssertionRewriter(ast.NodeVisitor):
current = self.stack.pop() current = self.stack.pop()
if self.stack: if self.stack:
self.explanation_specifiers = self.stack[-1] self.explanation_specifiers = self.stack[-1]
keys = [ast.Constant(key) for key in current.keys()] keys: list[ast.expr | None] = [ast.Constant(key) for key in current.keys()]
format_dict = ast.Dict(keys, list(current.values())) format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict) form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter)) name = "@py_format" + str(next(self.variable_counter))
@ -844,13 +840,13 @@ class AssertionRewriter(ast.NodeVisitor):
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: def generic_visit(self, node: ast.AST) -> tuple[ast.Name, str]:
"""Handle expressions we don't have custom code for.""" """Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr) assert isinstance(node, ast.expr)
res = self.assign(node) res = self.assign(node)
return res, self.explanation_param(self.display(res)) return res, self.explanation_param(self.display(res))
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
"""Return the AST statements to replace the ast.Assert instance. """Return the AST statements to replace the ast.Assert instance.
This rewrites the test of an assertion to provide This rewrites the test of an assertion to provide
@ -874,15 +870,15 @@ class AssertionRewriter(ast.NodeVisitor):
lineno=assert_.lineno, lineno=assert_.lineno,
) )
self.statements: List[ast.stmt] = [] self.statements: list[ast.stmt] = []
self.variables: List[str] = [] self.variables: list[str] = []
self.variable_counter = itertools.count() self.variable_counter = itertools.count()
if self.enable_assertion_pass_hook: if self.enable_assertion_pass_hook:
self.format_variables: List[str] = [] self.format_variables: list[str] = []
self.stack: List[Dict[str, ast.expr]] = [] self.stack: list[dict[str, ast.expr]] = []
self.expl_stmts: List[ast.stmt] = [] self.expl_stmts: list[ast.stmt] = []
self.push_format_context() self.push_format_context()
# Rewrite assert into a bunch of statements. # Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test) top_condition, explanation = self.visit(assert_.test)
@ -926,13 +922,13 @@ class AssertionRewriter(ast.NodeVisitor):
[*self.expl_stmts, hook_call_pass], [*self.expl_stmts, hook_call_pass],
[], [],
) )
statements_pass = [hook_impl_test] statements_pass: list[ast.stmt] = [hook_impl_test]
# Test for assertion condition # Test for assertion condition
main_test = ast.If(negation, statements_fail, statements_pass) main_test = ast.If(negation, statements_fail, statements_pass)
self.statements.append(main_test) self.statements.append(main_test)
if self.format_variables: if self.format_variables:
variables = [ variables: list[ast.expr] = [
ast.Name(name, ast.Store()) for name in self.format_variables ast.Name(name, ast.Store()) for name in self.format_variables
] ]
clear_format = ast.Assign(variables, ast.Constant(None)) clear_format = ast.Assign(variables, ast.Constant(None))
@ -968,7 +964,7 @@ class AssertionRewriter(ast.NodeVisitor):
ast.copy_location(node, assert_) ast.copy_location(node, assert_)
return self.statements return self.statements
def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]: def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target # This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name() # name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable. # thinks it's acceptable.
@ -980,7 +976,7 @@ class AssertionRewriter(ast.NodeVisitor):
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
return name, self.explanation_param(expr) return name, self.explanation_param(expr)
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or # Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable. # _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], []) locs = ast.Call(self.builtin("locals"), [], [])
@ -990,7 +986,7 @@ class AssertionRewriter(ast.NodeVisitor):
expr = ast.IfExp(test, self.display(name), ast.Constant(name.id)) expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
return name, self.explanation_param(expr) return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
res_var = self.variable() res_var = self.variable()
expl_list = self.assign(ast.List([], ast.Load())) expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load()) app = ast.Attribute(expl_list, "append", ast.Load())
@ -1002,7 +998,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Process each operand, short-circuiting if needed. # Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values): for i, v in enumerate(boolop.values):
if i: if i:
fail_inner: List[ast.stmt] = [] fail_inner: list[ast.stmt] = []
# cond is set in a prior loop iteration below # cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner self.expl_stmts = fail_inner
@ -1030,7 +1026,7 @@ class AssertionRewriter(ast.NodeVisitor):
cond: ast.expr = res cond: ast.expr = res
if is_or: if is_or:
cond = ast.UnaryOp(ast.Not(), cond) cond = ast.UnaryOp(ast.Not(), cond)
inner: List[ast.stmt] = [] inner: list[ast.stmt] = []
self.statements.append(ast.If(cond, inner, [])) self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner self.statements = body = inner
self.statements = save self.statements = save
@ -1039,13 +1035,13 @@ class AssertionRewriter(ast.NodeVisitor):
expl = self.pop_format_context(expl_template) expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl) return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]: def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__] pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand) operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res)) res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,) return res, pattern % (operand_expl,)
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]: def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__] symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left) left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right) right_expr, right_expl = self.visit(binop.right)
@ -1053,7 +1049,7 @@ class AssertionRewriter(ast.NodeVisitor):
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation return res, explanation
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
new_func, func_expl = self.visit(call.func) new_func, func_expl = self.visit(call.func)
arg_expls = [] arg_expls = []
new_args = [] new_args = []
@ -1085,13 +1081,13 @@ class AssertionRewriter(ast.NodeVisitor):
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}" outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
return res, outer_expl return res, outer_expl
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]: def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]:
# A Starred node can appear in a function call. # A Starred node can appear in a function call.
res, expl = self.visit(starred.value) res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx) new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl return new_starred, "*" + expl
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load): if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr) return self.generic_visit(attr)
value, value_expl = self.visit(attr.value) value, value_expl = self.visit(attr.value)
@ -1101,7 +1097,7 @@ class AssertionRewriter(ast.NodeVisitor):
expl = pat % (res_expl, res_expl, value_expl, attr.attr) expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl return res, expl
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
self.push_format_context() self.push_format_context()
# We first check if we have overwritten a variable in the previous assert # We first check if we have overwritten a variable in the previous assert
if isinstance( if isinstance(
@ -1114,11 +1110,11 @@ class AssertionRewriter(ast.NodeVisitor):
if isinstance(comp.left, (ast.Compare, ast.BoolOp)): if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})" left_expl = f"({left_expl})"
res_variables = [self.variable() for i in range(len(comp.ops))] res_variables = [self.variable() for i in range(len(comp.ops))]
load_names = [ast.Name(v, ast.Load()) for v in res_variables] load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables]
it = zip(range(len(comp.ops)), comp.ops, comp.comparators) it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
expls = [] expls: list[ast.expr] = []
syms = [] syms: list[ast.expr] = []
results = [left_res] results = [left_res]
for i, op, next_operand in it: for i, op, next_operand in it:
if ( if (

View File

@ -4,8 +4,7 @@ Current default behaviour is to truncate assertion explanations at
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI. terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
""" """
from typing import List from __future__ import annotations
from typing import Optional
from _pytest.assertion import util from _pytest.assertion import util
from _pytest.config import Config from _pytest.config import Config
@ -18,8 +17,8 @@ USAGE_MSG = "use '-vv' to show"
def truncate_if_required( def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None explanation: list[str], item: Item, max_length: int | None = None
) -> List[str]: ) -> list[str]:
"""Truncate this assertion explanation if the given test item is eligible.""" """Truncate this assertion explanation if the given test item is eligible."""
if _should_truncate_item(item): if _should_truncate_item(item):
return _truncate_explanation(explanation) return _truncate_explanation(explanation)
@ -33,10 +32,10 @@ def _should_truncate_item(item: Item) -> bool:
def _truncate_explanation( def _truncate_explanation(
input_lines: List[str], input_lines: list[str],
max_lines: Optional[int] = None, max_lines: int | None = None,
max_chars: Optional[int] = None, max_chars: int | None = None,
) -> List[str]: ) -> list[str]:
"""Truncate given list of strings that makes up the assertion explanation. """Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches Truncates to either 8 lines, or 640 characters - whichever the input reaches
@ -100,7 +99,7 @@ def _truncate_explanation(
] ]
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]: def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
# Find point at which input length exceeds total allowed length # Find point at which input length exceeds total allowed length
iterated_char_count = 0 iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines): for iterated_index, input_line in enumerate(input_lines):

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Utilities for assertion debugging.""" """Utilities for assertion debugging."""
from __future__ import annotations
import collections.abc import collections.abc
import os import os
import pprint import pprint
@ -8,10 +10,8 @@ from typing import AbstractSet
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Iterable from typing import Iterable
from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import Optional
from typing import Protocol from typing import Protocol
from typing import Sequence from typing import Sequence
from unicodedata import normalize from unicodedata import normalize
@ -28,14 +28,14 @@ from _pytest.config import Config
# interpretation code and assertion rewriter to detect this plugin was # interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the # loaded and in turn call the hooks defined here as part of the
# DebugInterpreter. # DebugInterpreter.
_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None _reprcompare: Callable[[str, object, object], str | None] | None = None
# Works similarly as _reprcompare attribute. Is populated with the hook call # Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called. # when pytest_runtest_setup is called.
_assertion_pass: Optional[Callable[[int, str, str], None]] = None _assertion_pass: Callable[[int, str, str], None] | None = None
# Config object which is assigned during pytest_runtest_protocol. # Config object which is assigned during pytest_runtest_protocol.
_config: Optional[Config] = None _config: Config | None = None
class _HighlightFunc(Protocol): class _HighlightFunc(Protocol):
@ -58,7 +58,7 @@ def format_explanation(explanation: str) -> str:
return "\n".join(result) return "\n".join(result)
def _split_explanation(explanation: str) -> List[str]: def _split_explanation(explanation: str) -> list[str]:
r"""Return a list of individual lines in the explanation. r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'. This will return a list of lines split on '\n{', '\n}' and '\n~'.
@ -75,7 +75,7 @@ def _split_explanation(explanation: str) -> List[str]:
return lines return lines
def _format_lines(lines: Sequence[str]) -> List[str]: def _format_lines(lines: Sequence[str]) -> list[str]:
"""Format the individual lines. """Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting This will replace the '{', '}' and '~' characters of our mini formatting
@ -169,7 +169,7 @@ def has_default_eq(
def assertrepr_compare( def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> Optional[List[str]]: ) -> list[str] | None:
"""Return specialised explanations for some operators/operands.""" """Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
@ -239,7 +239,7 @@ def assertrepr_compare(
def _compare_eq_any( def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0 left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
) -> List[str]: ) -> list[str]:
explanation = [] explanation = []
if istext(left) and istext(right): if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose) explanation = _diff_text(left, right, verbose)
@ -274,7 +274,7 @@ def _compare_eq_any(
return explanation return explanation
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: def _diff_text(left: str, right: str, verbose: int = 0) -> list[str]:
"""Return the explanation for the diff between text. """Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing Unless --verbose is used this will skip leading and trailing
@ -282,7 +282,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
""" """
from difflib import ndiff from difflib import ndiff
explanation: List[str] = [] explanation: list[str] = []
if verbose < 1: if verbose < 1:
i = 0 # just in case left or right has zero length i = 0 # just in case left or right has zero length
@ -327,7 +327,7 @@ def _compare_eq_iterable(
right: Iterable[Any], right: Iterable[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
if verbose <= 0 and not running_on_ci(): if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"] return ["Use -v to get more diff"]
# dynamic import to speedup pytest # dynamic import to speedup pytest
@ -356,9 +356,9 @@ def _compare_eq_sequence(
right: Sequence[Any], right: Sequence[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: List[str] = [] explanation: list[str] = []
len_left = len(left) len_left = len(left)
len_right = len(right) len_right = len(right)
for i in range(min(len_left, len_right)): for i in range(min(len_left, len_right)):
@ -417,7 +417,7 @@ def _compare_eq_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = [] explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
@ -429,7 +429,7 @@ def _compare_gt_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter) explanation = _compare_gte_set(left, right, highlighter)
if not explanation: if not explanation:
return ["Both sets are equal"] return ["Both sets are equal"]
@ -441,7 +441,7 @@ def _compare_lt_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter) explanation = _compare_lte_set(left, right, highlighter)
if not explanation: if not explanation:
return ["Both sets are equal"] return ["Both sets are equal"]
@ -453,7 +453,7 @@ def _compare_gte_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter) return _set_one_sided_diff("right", right, left, highlighter)
@ -462,7 +462,7 @@ def _compare_lte_set(
right: AbstractSet[Any], right: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter) return _set_one_sided_diff("left", left, right, highlighter)
@ -471,7 +471,7 @@ def _set_one_sided_diff(
set1: AbstractSet[Any], set1: AbstractSet[Any],
set2: AbstractSet[Any], set2: AbstractSet[Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
) -> List[str]: ) -> list[str]:
explanation = [] explanation = []
diff = set1 - set2 diff = set1 - set2
if diff: if diff:
@ -486,8 +486,8 @@ def _compare_eq_dict(
right: Mapping[Any, Any], right: Mapping[Any, Any],
highlighter: _HighlightFunc, highlighter: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> list[str]:
explanation: List[str] = [] explanation: list[str] = []
set_left = set(left) set_left = set(left)
set_right = set(right) set_right = set(right)
common = set_left.intersection(set_right) common = set_left.intersection(set_right)
@ -531,7 +531,7 @@ def _compare_eq_dict(
def _compare_eq_cls( def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
) -> List[str]: ) -> list[str]:
if not has_default_eq(left): if not has_default_eq(left):
return [] return []
if isdatacls(left): if isdatacls(left):
@ -584,7 +584,7 @@ def _compare_eq_cls(
return explanation return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]: def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
index = text.find(term) index = text.find(term)
head = text[:index] head = text[:index]
tail = text[index + len(term) :] tail = text[index + len(term) :]

View File

@ -3,20 +3,17 @@
# This plugin was not named "cache" to avoid conflicts with the external # This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version. # pytest-cache version.
from __future__ import annotations
import dataclasses import dataclasses
import errno import errno
import json import json
import os import os
from pathlib import Path from pathlib import Path
import tempfile import tempfile
from typing import Dict
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import Iterable from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from typing import Union
from .pathlib import resolve_from_str from .pathlib import resolve_from_str
from .pathlib import rm_rf from .pathlib import rm_rf
@ -77,7 +74,7 @@ class Cache:
self._config = config self._config = config
@classmethod @classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> "Cache": def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
"""Create the Cache instance for a Config. """Create the Cache instance for a Config.
:meta private: :meta private:
@ -249,7 +246,7 @@ class Cache:
class LFPluginCollWrapper: class LFPluginCollWrapper:
def __init__(self, lfplugin: "LFPlugin") -> None: def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin self.lfplugin = lfplugin
self._collected_at_least_one_failure = False self._collected_at_least_one_failure = False
@ -263,7 +260,7 @@ class LFPluginCollWrapper:
lf_paths = self.lfplugin._last_failed_paths lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to prioritize last failed. # Use stable sort to prioritize last failed.
def sort_key(node: Union[nodes.Item, nodes.Collector]) -> bool: def sort_key(node: nodes.Item | nodes.Collector) -> bool:
return node.path in lf_paths return node.path in lf_paths
res.result = sorted( res.result = sorted(
@ -301,13 +298,13 @@ class LFPluginCollWrapper:
class LFPluginCollSkipfiles: class LFPluginCollSkipfiles:
def __init__(self, lfplugin: "LFPlugin") -> None: def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin self.lfplugin = lfplugin
@hookimpl @hookimpl
def pytest_make_collect_report( def pytest_make_collect_report(
self, collector: nodes.Collector self, collector: nodes.Collector
) -> Optional[CollectReport]: ) -> CollectReport | None:
if isinstance(collector, File): if isinstance(collector, File):
if collector.path not in self.lfplugin._last_failed_paths: if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1 self.lfplugin._skipped_files += 1
@ -326,9 +323,9 @@ class LFPlugin:
active_keys = "lf", "failedfirst" active_keys = "lf", "failedfirst"
self.active = any(config.getoption(key) for key in active_keys) self.active = any(config.getoption(key) for key in active_keys)
assert config.cache assert config.cache
self.lastfailed: Dict[str, bool] = config.cache.get("cache/lastfailed", {}) self.lastfailed: dict[str, bool] = config.cache.get("cache/lastfailed", {})
self._previously_failed_count: Optional[int] = None self._previously_failed_count: int | None = None
self._report_status: Optional[str] = None self._report_status: str | None = None
self._skipped_files = 0 # count skipped files during collection due to --lf self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"): if config.getoption("lf"):
@ -337,7 +334,7 @@ class LFPlugin:
LFPluginCollWrapper(self), "lfplugin-collwrapper" LFPluginCollWrapper(self), "lfplugin-collwrapper"
) )
def get_last_failed_paths(self) -> Set[Path]: def get_last_failed_paths(self) -> set[Path]:
"""Return a set with all Paths of the previously failed nodeids and """Return a set with all Paths of the previously failed nodeids and
their parents.""" their parents."""
rootpath = self.config.rootpath rootpath = self.config.rootpath
@ -348,7 +345,7 @@ class LFPlugin:
result.update(path.parents) result.update(path.parents)
return {x for x in result if x.exists()} return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> Optional[str]: def pytest_report_collectionfinish(self) -> str | None:
if self.active and self.config.getoption("verbose") >= 0: if self.active and self.config.getoption("verbose") >= 0:
return "run-last-failure: %s" % self._report_status return "run-last-failure: %s" % self._report_status
return None return None
@ -370,7 +367,7 @@ class LFPlugin:
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item] self, config: Config, items: list[nodes.Item]
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
res = yield res = yield
@ -442,13 +439,13 @@ class NFPlugin:
@hookimpl(wrapper=True, tryfirst=True) @hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
self, items: List[nodes.Item] self, items: list[nodes.Item]
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
res = yield res = yield
if self.active: if self.active:
new_items: Dict[str, nodes.Item] = {} new_items: dict[str, nodes.Item] = {}
other_items: Dict[str, nodes.Item] = {} other_items: dict[str, nodes.Item] = {}
for item in items: for item in items:
if item.nodeid not in self.cached_nodeids: if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item new_items[item.nodeid] = item
@ -464,7 +461,7 @@ class NFPlugin:
return res return res
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]: def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True) return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True)
def pytest_sessionfinish(self) -> None: def pytest_sessionfinish(self) -> None:
@ -541,7 +538,7 @@ def pytest_addoption(parser: Parser) -> None:
) )
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.cacheshow and not config.option.help: if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session from _pytest.main import wrap_session
@ -572,7 +569,7 @@ def cache(request: FixtureRequest) -> Cache:
return request.config.cache return request.config.cache
def pytest_report_header(config: Config) -> Optional[str]: def pytest_report_header(config: Config) -> str | None:
"""Display cachedir with --cache-show and if non-default.""" """Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache": if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None assert config.cache is not None

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Per-test stdout/stderr capturing mechanism.""" """Per-test stdout/stderr capturing mechanism."""
from __future__ import annotations
import abc import abc
import collections import collections
import contextlib import contextlib
@ -19,15 +21,14 @@ from typing import Generator
from typing import Generic from typing import Generic
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Literal from typing import Literal
from typing import NamedTuple from typing import NamedTuple
from typing import Optional
from typing import TextIO from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
if TYPE_CHECKING:
from typing_extensions import Self
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import hookimpl from _pytest.config import hookimpl
@ -213,7 +214,7 @@ class DontReadFromInput(TextIO):
def __next__(self) -> str: def __next__(self) -> str:
return self.readline() return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]: def readlines(self, hint: int | None = -1) -> list[str]:
raise OSError( raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`." "pytest: reading from stdin while output is captured! Consider using `-s`."
) )
@ -245,7 +246,7 @@ class DontReadFromInput(TextIO):
def tell(self) -> int: def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()") raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
def truncate(self, size: Optional[int] = None) -> int: def truncate(self, size: int | None = None) -> int:
raise UnsupportedOperation("cannot truncate stdin") raise UnsupportedOperation("cannot truncate stdin")
def write(self, data: str) -> int: def write(self, data: str) -> int:
@ -257,14 +258,14 @@ class DontReadFromInput(TextIO):
def writable(self) -> bool: def writable(self) -> bool:
return False return False
def __enter__(self) -> "DontReadFromInput": def __enter__(self) -> Self:
return self return self
def __exit__( def __exit__(
self, self,
type: Optional[Type[BaseException]], type: type[BaseException] | None,
value: Optional[BaseException], value: BaseException | None,
traceback: Optional[TracebackType], traceback: TracebackType | None,
) -> None: ) -> None:
pass pass
@ -339,7 +340,7 @@ class NoCapture(CaptureBase[str]):
class SysCaptureBase(CaptureBase[AnyStr]): class SysCaptureBase(CaptureBase[AnyStr]):
def __init__( def __init__(
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False self, fd: int, tmpfile: TextIO | None = None, *, tee: bool = False
) -> None: ) -> None:
name = patchsysdict[fd] name = patchsysdict[fd]
self._old: TextIO = getattr(sys, name) self._old: TextIO = getattr(sys, name)
@ -370,7 +371,7 @@ class SysCaptureBase(CaptureBase[AnyStr]):
self.tmpfile, self.tmpfile,
) )
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None: def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert ( assert (
self._state in states self._state in states
), "cannot {} in state {!r}: expected one of {}".format( ), "cannot {} in state {!r}: expected one of {}".format(
@ -457,7 +458,7 @@ class FDCaptureBase(CaptureBase[AnyStr]):
# Further complications are the need to support suspend() and the # Further complications are the need to support suspend() and the
# possibility of FD reuse (e.g. the tmpfile getting the very same # possibility of FD reuse (e.g. the tmpfile getting the very same
# target FD). The following approach is robust, I believe. # target FD). The following approach is robust, I believe.
self.targetfd_invalid: Optional[int] = os.open(os.devnull, os.O_RDWR) self.targetfd_invalid: int | None = os.open(os.devnull, os.O_RDWR)
os.dup2(self.targetfd_invalid, targetfd) os.dup2(self.targetfd_invalid, targetfd)
else: else:
self.targetfd_invalid = None self.targetfd_invalid = None
@ -487,7 +488,7 @@ class FDCaptureBase(CaptureBase[AnyStr]):
f"_state={self._state!r} tmpfile={self.tmpfile!r}>" f"_state={self._state!r} tmpfile={self.tmpfile!r}>"
) )
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None: def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert ( assert (
self._state in states self._state in states
), "cannot {} in state {!r}: expected one of {}".format( ), "cannot {} in state {!r}: expected one of {}".format(
@ -609,13 +610,13 @@ class MultiCapture(Generic[AnyStr]):
def __init__( def __init__(
self, self,
in_: Optional[CaptureBase[AnyStr]], in_: CaptureBase[AnyStr] | None,
out: Optional[CaptureBase[AnyStr]], out: CaptureBase[AnyStr] | None,
err: Optional[CaptureBase[AnyStr]], err: CaptureBase[AnyStr] | None,
) -> None: ) -> None:
self.in_: Optional[CaptureBase[AnyStr]] = in_ self.in_: CaptureBase[AnyStr] | None = in_
self.out: Optional[CaptureBase[AnyStr]] = out self.out: CaptureBase[AnyStr] | None = out
self.err: Optional[CaptureBase[AnyStr]] = err self.err: CaptureBase[AnyStr] | None = err
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -632,7 +633,7 @@ class MultiCapture(Generic[AnyStr]):
if self.err: if self.err:
self.err.start() self.err.start()
def pop_outerr_to_orig(self) -> Tuple[AnyStr, AnyStr]: def pop_outerr_to_orig(self) -> tuple[AnyStr, AnyStr]:
"""Pop current snapshot out/err capture and flush to orig streams.""" """Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr() out, err = self.readouterr()
if out: if out:
@ -725,8 +726,8 @@ class CaptureManager:
def __init__(self, method: _CaptureMethod) -> None: def __init__(self, method: _CaptureMethod) -> None:
self._method: Final = method self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None self._global_capturing: MultiCapture[str] | None = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None self._capture_fixture: CaptureFixture[Any] | None = None
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -734,7 +735,7 @@ class CaptureManager:
f"_capture_fixture={self._capture_fixture!r}>" f"_capture_fixture={self._capture_fixture!r}>"
) )
def is_capturing(self) -> Union[str, bool]: def is_capturing(self) -> str | bool:
if self.is_globally_capturing(): if self.is_globally_capturing():
return "global" return "global"
if self._capture_fixture: if self._capture_fixture:
@ -782,7 +783,7 @@ class CaptureManager:
# Fixture Control # Fixture Control
def set_fixture(self, capture_fixture: "CaptureFixture[Any]") -> None: def set_fixture(self, capture_fixture: CaptureFixture[Any]) -> None:
if self._capture_fixture: if self._capture_fixture:
current_fixture = self._capture_fixture.request.fixturename current_fixture = self._capture_fixture.request.fixturename
requested_fixture = capture_fixture.request.fixturename requested_fixture = capture_fixture.request.fixturename
@ -897,15 +898,15 @@ class CaptureFixture(Generic[AnyStr]):
def __init__( def __init__(
self, self,
captureclass: Type[CaptureBase[AnyStr]], captureclass: type[CaptureBase[AnyStr]],
request: SubRequest, request: SubRequest,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self.captureclass: Type[CaptureBase[AnyStr]] = captureclass self.captureclass: type[CaptureBase[AnyStr]] = captureclass
self.request = request self.request = request
self._capture: Optional[MultiCapture[AnyStr]] = None self._capture: MultiCapture[AnyStr] | None = None
self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Command line options, ini-file and conftest.py processing.""" """Command line options, ini-file and conftest.py processing."""
from __future__ import annotations
import argparse import argparse
import collections.abc import collections.abc
import copy import copy
@ -11,7 +13,7 @@ import glob
import importlib.metadata import importlib.metadata
import inspect import inspect
import os import os
from pathlib import Path import pathlib
import re import re
import shlex import shlex
import sys import sys
@ -21,22 +23,16 @@ from types import FunctionType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import Dict
from typing import Final from typing import Final
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import IO from typing import IO
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set
from typing import TextIO from typing import TextIO
from typing import Tuple
from typing import Type from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import warnings import warnings
import pluggy import pluggy
@ -74,6 +70,7 @@ if TYPE_CHECKING:
from .argparsing import Argument from .argparsing import Argument
from .argparsing import Parser from .argparsing import Parser
from _pytest._code.code import _TracebackStyle from _pytest._code.code import _TracebackStyle
from _pytest.cacheprovider import Cache
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
@ -117,7 +114,7 @@ class ExitCode(enum.IntEnum):
class ConftestImportFailure(Exception): class ConftestImportFailure(Exception):
def __init__( def __init__(
self, self,
path: Path, path: pathlib.Path,
*, *,
cause: Exception, cause: Exception,
) -> None: ) -> None:
@ -140,9 +137,9 @@ def filter_traceback_for_conftest_import_failure(
def main( def main(
args: Optional[Union[List[str], "os.PathLike[str]"]] = None, args: list[str] | os.PathLike[str] | None = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None, plugins: Sequence[str | _PluggyPlugin] | None = None,
) -> Union[int, ExitCode]: ) -> int | ExitCode:
"""Perform an in-process test run. """Perform an in-process test run.
:param args: :param args:
@ -175,9 +172,7 @@ def main(
return ExitCode.USAGE_ERROR return ExitCode.USAGE_ERROR
else: else:
try: try:
ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main( ret: ExitCode | int = config.hook.pytest_cmdline_main(config=config)
config=config
)
try: try:
return ExitCode(ret) return ExitCode(ret)
except ValueError: except ValueError:
@ -285,9 +280,9 @@ builtin_plugins.add("pytester_assertions")
def get_config( def get_config(
args: Optional[List[str]] = None, args: list[str] | None = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None, plugins: Sequence[str | _PluggyPlugin] | None = None,
) -> "Config": ) -> Config:
# subsequent calls to main will create a fresh instance # subsequent calls to main will create a fresh instance
pluginmanager = PytestPluginManager() pluginmanager = PytestPluginManager()
config = Config( config = Config(
@ -295,7 +290,7 @@ def get_config(
invocation_params=Config.InvocationParams( invocation_params=Config.InvocationParams(
args=args or (), args=args or (),
plugins=plugins, plugins=plugins,
dir=Path.cwd(), dir=pathlib.Path.cwd(),
), ),
) )
@ -309,7 +304,7 @@ def get_config(
return config return config
def get_plugin_manager() -> "PytestPluginManager": def get_plugin_manager() -> PytestPluginManager:
"""Obtain a new instance of the """Obtain a new instance of the
:py:class:`pytest.PytestPluginManager`, with default plugins :py:class:`pytest.PytestPluginManager`, with default plugins
already loaded. already loaded.
@ -321,9 +316,9 @@ def get_plugin_manager() -> "PytestPluginManager":
def _prepareconfig( def _prepareconfig(
args: Optional[Union[List[str], "os.PathLike[str]"]] = None, args: list[str] | os.PathLike[str] | None = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None, plugins: Sequence[str | _PluggyPlugin] | None = None,
) -> "Config": ) -> Config:
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
elif isinstance(args, os.PathLike): elif isinstance(args, os.PathLike):
@ -352,7 +347,7 @@ def _prepareconfig(
raise raise
def _get_directory(path: Path) -> Path: def _get_directory(path: pathlib.Path) -> pathlib.Path:
"""Get the directory of a path - itself if already a directory.""" """Get the directory of a path - itself if already a directory."""
if path.is_file(): if path.is_file():
return path.parent return path.parent
@ -363,14 +358,14 @@ def _get_directory(path: Path) -> Path:
def _get_legacy_hook_marks( def _get_legacy_hook_marks(
method: Any, method: Any,
hook_type: str, hook_type: str,
opt_names: Tuple[str, ...], opt_names: tuple[str, ...],
) -> Dict[str, bool]: ) -> dict[str, bool]:
if TYPE_CHECKING: if TYPE_CHECKING:
# abuse typeguard from importlib to avoid massive method type union thats lacking a alias # abuse typeguard from importlib to avoid massive method type union thats lacking a alias
assert inspect.isroutine(method) assert inspect.isroutine(method)
known_marks: Set[str] = {m.name for m in getattr(method, "pytestmark", [])} known_marks: set[str] = {m.name for m in getattr(method, "pytestmark", [])}
must_warn: List[str] = [] must_warn: list[str] = []
opts: Dict[str, bool] = {} opts: dict[str, bool] = {}
for opt_name in opt_names: for opt_name in opt_names:
opt_attr = getattr(method, opt_name, AttributeError) opt_attr = getattr(method, opt_name, AttributeError)
if opt_attr is not AttributeError: if opt_attr is not AttributeError:
@ -409,13 +404,13 @@ class PytestPluginManager(PluginManager):
# -- State related to local conftest plugins. # -- State related to local conftest plugins.
# All loaded conftest modules. # All loaded conftest modules.
self._conftest_plugins: Set[types.ModuleType] = set() self._conftest_plugins: set[types.ModuleType] = set()
# All conftest modules applicable for a directory. # All conftest modules applicable for a directory.
# This includes the directory's own conftest modules as well # This includes the directory's own conftest modules as well
# as those of its parent directories. # as those of its parent directories.
self._dirpath2confmods: Dict[Path, List[types.ModuleType]] = {} self._dirpath2confmods: dict[pathlib.Path, list[types.ModuleType]] = {}
# Cutoff directory above which conftests are no longer discovered. # Cutoff directory above which conftests are no longer discovered.
self._confcutdir: Optional[Path] = None self._confcutdir: pathlib.Path | None = None
# If set, conftest loading is skipped. # If set, conftest loading is skipped.
self._noconftest = False self._noconftest = False
@ -429,7 +424,7 @@ class PytestPluginManager(PluginManager):
# previously we would issue a warning when a plugin was skipped, but # previously we would issue a warning when a plugin was skipped, but
# since we refactored warnings as first citizens of Config, they are # since we refactored warnings as first citizens of Config, they are
# just stored here to be used later. # just stored here to be used later.
self.skipped_plugins: List[Tuple[str, str]] = [] self.skipped_plugins: list[tuple[str, str]] = []
self.add_hookspecs(_pytest.hookspec) self.add_hookspecs(_pytest.hookspec)
self.register(self) self.register(self)
@ -455,7 +450,7 @@ class PytestPluginManager(PluginManager):
def parse_hookimpl_opts( def parse_hookimpl_opts(
self, plugin: _PluggyPlugin, name: str self, plugin: _PluggyPlugin, name: str
) -> Optional[HookimplOpts]: ) -> HookimplOpts | None:
""":meta private:""" """:meta private:"""
# pytest hooks are always prefixed with "pytest_", # pytest hooks are always prefixed with "pytest_",
# so we avoid accessing possibly non-readable attributes # so we avoid accessing possibly non-readable attributes
@ -479,7 +474,7 @@ class PytestPluginManager(PluginManager):
method, "impl", ("tryfirst", "trylast", "optionalhook", "hookwrapper") method, "impl", ("tryfirst", "trylast", "optionalhook", "hookwrapper")
) )
def parse_hookspec_opts(self, module_or_class, name: str) -> Optional[HookspecOpts]: def parse_hookspec_opts(self, module_or_class, name: str) -> HookspecOpts | None:
""":meta private:""" """:meta private:"""
opts = super().parse_hookspec_opts(module_or_class, name) opts = super().parse_hookspec_opts(module_or_class, name)
if opts is None: if opts is None:
@ -492,9 +487,7 @@ class PytestPluginManager(PluginManager):
) )
return opts return opts
def register( def register(self, plugin: _PluggyPlugin, name: str | None = None) -> str | None:
self, plugin: _PluggyPlugin, name: Optional[str] = None
) -> Optional[str]:
if name in _pytest.deprecated.DEPRECATED_EXTERNAL_PLUGINS: if name in _pytest.deprecated.DEPRECATED_EXTERNAL_PLUGINS:
warnings.warn( warnings.warn(
PytestConfigWarning( PytestConfigWarning(
@ -521,14 +514,14 @@ class PytestPluginManager(PluginManager):
def getplugin(self, name: str): def getplugin(self, name: str):
# Support deprecated naming because plugins (xdist e.g.) use it. # Support deprecated naming because plugins (xdist e.g.) use it.
plugin: Optional[_PluggyPlugin] = self.get_plugin(name) plugin: _PluggyPlugin | None = self.get_plugin(name)
return plugin return plugin
def hasplugin(self, name: str) -> bool: def hasplugin(self, name: str) -> bool:
"""Return whether a plugin with the given name is registered.""" """Return whether a plugin with the given name is registered."""
return bool(self.get_plugin(name)) return bool(self.get_plugin(name))
def pytest_configure(self, config: "Config") -> None: def pytest_configure(self, config: Config) -> None:
""":meta private:""" """:meta private:"""
# XXX now that the pluginmanager exposes hookimpl(tryfirst...) # XXX now that the pluginmanager exposes hookimpl(tryfirst...)
# we should remove tryfirst/trylast as markers. # we should remove tryfirst/trylast as markers.
@ -551,13 +544,13 @@ class PytestPluginManager(PluginManager):
# #
def _set_initial_conftests( def _set_initial_conftests(
self, self,
args: Sequence[Union[str, Path]], args: Sequence[str | pathlib.Path],
pyargs: bool, pyargs: bool,
noconftest: bool, noconftest: bool,
rootpath: Path, rootpath: pathlib.Path,
confcutdir: Optional[Path], confcutdir: pathlib.Path | None,
invocation_dir: Path, invocation_dir: pathlib.Path,
importmode: Union[ImportMode, str], importmode: ImportMode | str,
*, *,
consider_namespace_packages: bool, consider_namespace_packages: bool,
) -> None: ) -> None:
@ -600,7 +593,7 @@ class PytestPluginManager(PluginManager):
consider_namespace_packages=consider_namespace_packages, consider_namespace_packages=consider_namespace_packages,
) )
def _is_in_confcutdir(self, path: Path) -> bool: def _is_in_confcutdir(self, path: pathlib.Path) -> bool:
"""Whether to consider the given path to load conftests from.""" """Whether to consider the given path to load conftests from."""
if self._confcutdir is None: if self._confcutdir is None:
return True return True
@ -617,9 +610,9 @@ class PytestPluginManager(PluginManager):
def _try_load_conftest( def _try_load_conftest(
self, self,
anchor: Path, anchor: pathlib.Path,
importmode: Union[str, ImportMode], importmode: str | ImportMode,
rootpath: Path, rootpath: pathlib.Path,
*, *,
consider_namespace_packages: bool, consider_namespace_packages: bool,
) -> None: ) -> None:
@ -642,9 +635,9 @@ class PytestPluginManager(PluginManager):
def _loadconftestmodules( def _loadconftestmodules(
self, self,
path: Path, path: pathlib.Path,
importmode: Union[str, ImportMode], importmode: str | ImportMode,
rootpath: Path, rootpath: pathlib.Path,
*, *,
consider_namespace_packages: bool, consider_namespace_packages: bool,
) -> None: ) -> None:
@ -672,15 +665,15 @@ class PytestPluginManager(PluginManager):
clist.append(mod) clist.append(mod)
self._dirpath2confmods[directory] = clist self._dirpath2confmods[directory] = clist
def _getconftestmodules(self, path: Path) -> Sequence[types.ModuleType]: def _getconftestmodules(self, path: pathlib.Path) -> Sequence[types.ModuleType]:
directory = self._get_directory(path) directory = self._get_directory(path)
return self._dirpath2confmods.get(directory, ()) return self._dirpath2confmods.get(directory, ())
def _rget_with_confmod( def _rget_with_confmod(
self, self,
name: str, name: str,
path: Path, path: pathlib.Path,
) -> Tuple[types.ModuleType, Any]: ) -> tuple[types.ModuleType, Any]:
modules = self._getconftestmodules(path) modules = self._getconftestmodules(path)
for mod in reversed(modules): for mod in reversed(modules):
try: try:
@ -691,9 +684,9 @@ class PytestPluginManager(PluginManager):
def _importconftest( def _importconftest(
self, self,
conftestpath: Path, conftestpath: pathlib.Path,
importmode: Union[str, ImportMode], importmode: str | ImportMode,
rootpath: Path, rootpath: pathlib.Path,
*, *,
consider_namespace_packages: bool, consider_namespace_packages: bool,
) -> types.ModuleType: ) -> types.ModuleType:
@ -745,7 +738,7 @@ class PytestPluginManager(PluginManager):
def _check_non_top_pytest_plugins( def _check_non_top_pytest_plugins(
self, self,
mod: types.ModuleType, mod: types.ModuleType,
conftestpath: Path, conftestpath: pathlib.Path,
) -> None: ) -> None:
if ( if (
hasattr(mod, "pytest_plugins") hasattr(mod, "pytest_plugins")
@ -831,7 +824,7 @@ class PytestPluginManager(PluginManager):
self._import_plugin_specs(getattr(mod, "pytest_plugins", [])) self._import_plugin_specs(getattr(mod, "pytest_plugins", []))
def _import_plugin_specs( def _import_plugin_specs(
self, spec: Union[None, types.ModuleType, str, Sequence[str]] self, spec: None | types.ModuleType | str | Sequence[str]
) -> None: ) -> None:
plugins = _get_plugin_specs_as_list(spec) plugins = _get_plugin_specs_as_list(spec)
for import_spec in plugins: for import_spec in plugins:
@ -876,8 +869,8 @@ class PytestPluginManager(PluginManager):
def _get_plugin_specs_as_list( def _get_plugin_specs_as_list(
specs: Union[None, types.ModuleType, str, Sequence[str]], specs: None | types.ModuleType | str | Sequence[str],
) -> List[str]: ) -> list[str]:
"""Parse a plugins specification into a list of plugin names.""" """Parse a plugins specification into a list of plugin names."""
# None means empty. # None means empty.
if specs is None: if specs is None:
@ -999,24 +992,27 @@ class Config:
Plugins accessing ``InvocationParams`` must be aware of that. Plugins accessing ``InvocationParams`` must be aware of that.
""" """
args: Tuple[str, ...] args: tuple[str, ...]
"""The command-line arguments as passed to :func:`pytest.main`.""" """The command-line arguments as passed to :func:`pytest.main`."""
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] plugins: Sequence[str | _PluggyPlugin] | None
"""Extra plugins, might be `None`.""" """Extra plugins, might be `None`."""
dir: Path dir: pathlib.Path
"""The directory from which :func:`pytest.main` was invoked.""" """The directory from which :func:`pytest.main` was invoked. :type: pathlib.Path"""
def __init__( def __init__(
self, self,
*, *,
args: Iterable[str], args: Iterable[str],
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]], plugins: Sequence[str | _PluggyPlugin] | None,
dir: Path, dir: pathlib.Path,
) -> None: ) -> None:
object.__setattr__(self, "args", tuple(args)) object.__setattr__(self, "args", tuple(args))
object.__setattr__(self, "plugins", plugins) object.__setattr__(self, "plugins", plugins)
object.__setattr__(self, "dir", dir) object.__setattr__(self, "dir", dir)
# Set by cacheprovider plugin.
cache: Cache
class ArgsSource(enum.Enum): class ArgsSource(enum.Enum):
"""Indicates the source of the test arguments. """Indicates the source of the test arguments.
@ -1035,14 +1031,14 @@ class Config:
self, self,
pluginmanager: PytestPluginManager, pluginmanager: PytestPluginManager,
*, *,
invocation_params: Optional[InvocationParams] = None, invocation_params: InvocationParams | None = None,
) -> None: ) -> None:
from .argparsing import FILE_OR_DIR from .argparsing import FILE_OR_DIR
from .argparsing import Parser from .argparsing import Parser
if invocation_params is None: if invocation_params is None:
invocation_params = self.InvocationParams( invocation_params = self.InvocationParams(
args=(), plugins=None, dir=Path.cwd() args=(), plugins=None, dir=pathlib.Path.cwd()
) )
self.option = argparse.Namespace() self.option = argparse.Namespace()
@ -1080,25 +1076,20 @@ class Config:
self.trace = self.pluginmanager.trace.root.get("config") self.trace = self.pluginmanager.trace.root.get("config")
self.hook: pluggy.HookRelay = PathAwareHookProxy(self.pluginmanager.hook) # type: ignore[assignment] self.hook: pluggy.HookRelay = PathAwareHookProxy(self.pluginmanager.hook) # type: ignore[assignment]
self._inicache: Dict[str, Any] = {} self._inicache: dict[str, Any] = {}
self._override_ini: Sequence[str] = () self._override_ini: Sequence[str] = ()
self._opt2dest: Dict[str, str] = {} self._opt2dest: dict[str, str] = {}
self._cleanup: List[Callable[[], None]] = [] self._cleanup: list[Callable[[], None]] = []
self.pluginmanager.register(self, "pytestconfig") self.pluginmanager.register(self, "pytestconfig")
self._configured = False self._configured = False
self.hook.pytest_addoption.call_historic( self.hook.pytest_addoption.call_historic(
kwargs=dict(parser=self._parser, pluginmanager=self.pluginmanager) kwargs=dict(parser=self._parser, pluginmanager=self.pluginmanager)
) )
self.args_source = Config.ArgsSource.ARGS self.args_source = Config.ArgsSource.ARGS
self.args: List[str] = [] self.args: list[str] = []
if TYPE_CHECKING:
from _pytest.cacheprovider import Cache
self.cache: Optional[Cache] = None
@property @property
def rootpath(self) -> Path: def rootpath(self) -> pathlib.Path:
"""The path to the :ref:`rootdir <rootdir>`. """The path to the :ref:`rootdir <rootdir>`.
:type: pathlib.Path :type: pathlib.Path
@ -1108,11 +1099,9 @@ class Config:
return self._rootpath return self._rootpath
@property @property
def inipath(self) -> Optional[Path]: def inipath(self) -> pathlib.Path | None:
"""The path to the :ref:`configfile <configfiles>`. """The path to the :ref:`configfile <configfiles>`.
:type: Optional[pathlib.Path]
.. versionadded:: 6.1 .. versionadded:: 6.1
""" """
return self._inipath return self._inipath
@ -1139,15 +1128,15 @@ class Config:
fin() fin()
def get_terminal_writer(self) -> TerminalWriter: def get_terminal_writer(self) -> TerminalWriter:
terminalreporter: Optional[TerminalReporter] = self.pluginmanager.get_plugin( terminalreporter: TerminalReporter | None = self.pluginmanager.get_plugin(
"terminalreporter" "terminalreporter"
) )
assert terminalreporter is not None assert terminalreporter is not None
return terminalreporter._tw return terminalreporter._tw
def pytest_cmdline_parse( def pytest_cmdline_parse(
self, pluginmanager: PytestPluginManager, args: List[str] self, pluginmanager: PytestPluginManager, args: list[str]
) -> "Config": ) -> Config:
try: try:
self.parse(args) self.parse(args)
except UsageError: except UsageError:
@ -1173,7 +1162,7 @@ class Config:
def notify_exception( def notify_exception(
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
option: Optional[argparse.Namespace] = None, option: argparse.Namespace | None = None,
) -> None: ) -> None:
if option and getattr(option, "fulltrace", False): if option and getattr(option, "fulltrace", False):
style: _TracebackStyle = "long" style: _TracebackStyle = "long"
@ -1196,7 +1185,7 @@ class Config:
return nodeid return nodeid
@classmethod @classmethod
def fromdictargs(cls, option_dict, args) -> "Config": def fromdictargs(cls, option_dict, args) -> Config:
"""Constructor usable for subprocesses.""" """Constructor usable for subprocesses."""
config = get_config(args) config = get_config(args)
config.option.__dict__.update(option_dict) config.option.__dict__.update(option_dict)
@ -1205,7 +1194,7 @@ class Config:
config.pluginmanager.consider_pluginarg(x) config.pluginmanager.consider_pluginarg(x)
return config return config
def _processopt(self, opt: "Argument") -> None: def _processopt(self, opt: Argument) -> None:
for name in opt._short_opts + opt._long_opts: for name in opt._short_opts + opt._long_opts:
self._opt2dest[name] = opt.dest self._opt2dest[name] = opt.dest
@ -1214,7 +1203,7 @@ class Config:
setattr(self.option, opt.dest, opt.default) setattr(self.option, opt.dest, opt.default)
@hookimpl(trylast=True) @hookimpl(trylast=True)
def pytest_load_initial_conftests(self, early_config: "Config") -> None: def pytest_load_initial_conftests(self, early_config: Config) -> None:
# We haven't fully parsed the command line arguments yet, so # We haven't fully parsed the command line arguments yet, so
# early_config.args it not set yet. But we need it for # early_config.args it not set yet. But we need it for
# discovering the initial conftests. So "pre-run" the logic here. # discovering the initial conftests. So "pre-run" the logic here.
@ -1305,7 +1294,7 @@ class Config:
for name in _iter_rewritable_modules(package_files): for name in _iter_rewritable_modules(package_files):
hook.mark_rewrite(name) hook.mark_rewrite(name)
def _validate_args(self, args: List[str], via: str) -> List[str]: def _validate_args(self, args: list[str], via: str) -> list[str]:
"""Validate known args.""" """Validate known args."""
self._parser._config_source_hint = via # type: ignore self._parser._config_source_hint = via # type: ignore
try: try:
@ -1320,13 +1309,13 @@ class Config:
def _decide_args( def _decide_args(
self, self,
*, *,
args: List[str], args: list[str],
pyargs: bool, pyargs: bool,
testpaths: List[str], testpaths: list[str],
invocation_dir: Path, invocation_dir: pathlib.Path,
rootpath: Path, rootpath: pathlib.Path,
warn: bool, warn: bool,
) -> Tuple[List[str], ArgsSource]: ) -> tuple[list[str], ArgsSource]:
"""Decide the args (initial paths/nodeids) to use given the relevant inputs. """Decide the args (initial paths/nodeids) to use given the relevant inputs.
:param warn: Whether can issue warnings. :param warn: Whether can issue warnings.
@ -1362,7 +1351,7 @@ class Config:
result = [str(invocation_dir)] result = [str(invocation_dir)]
return result, source return result, source
def _preparse(self, args: List[str], addopts: bool = True) -> None: def _preparse(self, args: list[str], addopts: bool = True) -> None:
if addopts: if addopts:
env_addopts = os.environ.get("PYTEST_ADDOPTS", "") env_addopts = os.environ.get("PYTEST_ADDOPTS", "")
if len(env_addopts): if len(env_addopts):
@ -1486,11 +1475,11 @@ class Config:
self.issue_config_time_warning(PytestConfigWarning(message), stacklevel=3) self.issue_config_time_warning(PytestConfigWarning(message), stacklevel=3)
def _get_unknown_ini_keys(self) -> List[str]: def _get_unknown_ini_keys(self) -> list[str]:
parser_inicfg = self._parser._inidict parser_inicfg = self._parser._inidict
return [name for name in self.inicfg if name not in parser_inicfg] return [name for name in self.inicfg if name not in parser_inicfg]
def parse(self, args: List[str], addopts: bool = True) -> None: def parse(self, args: list[str], addopts: bool = True) -> None:
# Parse given cmdline arguments into this config object. # Parse given cmdline arguments into this config object.
assert ( assert (
self.args == [] self.args == []
@ -1594,7 +1583,7 @@ class Config:
# Meant for easy monkeypatching by legacypath plugin. # Meant for easy monkeypatching by legacypath plugin.
# Can be inlined back (with no cover removed) once legacypath is gone. # Can be inlined back (with no cover removed) once legacypath is gone.
def _getini_unknown_type(self, name: str, type: str, value: Union[str, List[str]]): def _getini_unknown_type(self, name: str, type: str, value: str | list[str]):
msg = f"unknown configuration type: {type}" msg = f"unknown configuration type: {type}"
raise ValueError(msg, value) # pragma: no cover raise ValueError(msg, value) # pragma: no cover
@ -1650,24 +1639,26 @@ class Config:
else: else:
return self._getini_unknown_type(name, type, value) return self._getini_unknown_type(name, type, value)
def _getconftest_pathlist(self, name: str, path: Path) -> Optional[List[Path]]: def _getconftest_pathlist(
self, name: str, path: pathlib.Path
) -> list[pathlib.Path] | None:
try: try:
mod, relroots = self.pluginmanager._rget_with_confmod(name, path) mod, relroots = self.pluginmanager._rget_with_confmod(name, path)
except KeyError: except KeyError:
return None return None
assert mod.__file__ is not None assert mod.__file__ is not None
modpath = Path(mod.__file__).parent modpath = pathlib.Path(mod.__file__).parent
values: List[Path] = [] values: list[pathlib.Path] = []
for relroot in relroots: for relroot in relroots:
if isinstance(relroot, os.PathLike): if isinstance(relroot, os.PathLike):
relroot = Path(relroot) relroot = pathlib.Path(relroot)
else: else:
relroot = relroot.replace("/", os.sep) relroot = relroot.replace("/", os.sep)
relroot = absolutepath(modpath / relroot) relroot = absolutepath(modpath / relroot)
values.append(relroot) values.append(relroot)
return values return values
def _get_override_ini_value(self, name: str) -> Optional[str]: def _get_override_ini_value(self, name: str) -> str | None:
value = None value = None
# override_ini is a list of "ini=value" options. # override_ini is a list of "ini=value" options.
# Always use the last item if multiple values are set for same ini-name, # Always use the last item if multiple values are set for same ini-name,
@ -1722,7 +1713,7 @@ class Config:
VERBOSITY_TEST_CASES: Final = "test_cases" VERBOSITY_TEST_CASES: Final = "test_cases"
_VERBOSITY_INI_DEFAULT: Final = "auto" _VERBOSITY_INI_DEFAULT: Final = "auto"
def get_verbosity(self, verbosity_type: Optional[str] = None) -> int: def get_verbosity(self, verbosity_type: str | None = None) -> int:
r"""Retrieve the verbosity level for a fine-grained verbosity type. r"""Retrieve the verbosity level for a fine-grained verbosity type.
:param verbosity_type: Verbosity type to get level for. If a level is :param verbosity_type: Verbosity type to get level for. If a level is
@ -1772,7 +1763,7 @@ class Config:
return f"verbosity_{verbosity_type}" return f"verbosity_{verbosity_type}"
@staticmethod @staticmethod
def _add_verbosity_ini(parser: "Parser", verbosity_type: str, help: str) -> None: def _add_verbosity_ini(parser: Parser, verbosity_type: str, help: str) -> None:
"""Add a output verbosity configuration option for the given output type. """Add a output verbosity configuration option for the given output type.
:param parser: Parser for command line arguments and ini-file values. :param parser: Parser for command line arguments and ini-file values.
@ -1828,7 +1819,7 @@ def _assertion_supported() -> bool:
def create_terminal_writer( def create_terminal_writer(
config: Config, file: Optional[TextIO] = None config: Config, file: TextIO | None = None
) -> TerminalWriter: ) -> TerminalWriter:
"""Create a TerminalWriter instance configured according to the options """Create a TerminalWriter instance configured according to the options
in the config object. in the config object.
@ -1872,7 +1863,7 @@ def _strtobool(val: str) -> bool:
@lru_cache(maxsize=50) @lru_cache(maxsize=50)
def parse_warning_filter( def parse_warning_filter(
arg: str, *, escape: bool arg: str, *, escape: bool
) -> Tuple["warnings._ActionKind", str, Type[Warning], str, int]: ) -> tuple[warnings._ActionKind, str, type[Warning], str, int]:
"""Parse a warnings filter string. """Parse a warnings filter string.
This is copied from warnings._setoption with the following changes: This is copied from warnings._setoption with the following changes:
@ -1914,11 +1905,11 @@ def parse_warning_filter(
parts.append("") parts.append("")
action_, message, category_, module, lineno_ = (s.strip() for s in parts) action_, message, category_, module, lineno_ = (s.strip() for s in parts)
try: try:
action: "warnings._ActionKind" = warnings._getaction(action_) # type: ignore[attr-defined] action: warnings._ActionKind = warnings._getaction(action_) # type: ignore[attr-defined]
except warnings._OptionError as e: except warnings._OptionError as e:
raise UsageError(error_template.format(error=str(e))) from None raise UsageError(error_template.format(error=str(e))) from None
try: try:
category: Type[Warning] = _resolve_warning_category(category_) category: type[Warning] = _resolve_warning_category(category_)
except Exception: except Exception:
exc_info = ExceptionInfo.from_current() exc_info = ExceptionInfo.from_current()
exception_text = exc_info.getrepr(style="native") exception_text = exc_info.getrepr(style="native")
@ -1941,7 +1932,7 @@ def parse_warning_filter(
return action, message, category, module, lineno return action, message, category, module, lineno
def _resolve_warning_category(category: str) -> Type[Warning]: def _resolve_warning_category(category: str) -> type[Warning]:
""" """
Copied from warnings._getcategory, but changed so it lets exceptions (specially ImportErrors) Copied from warnings._getcategory, but changed so it lets exceptions (specially ImportErrors)
propagate so we can get access to their tracebacks (#9218). propagate so we can get access to their tracebacks (#9218).

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import argparse import argparse
from gettext import gettext from gettext import gettext
import os import os
@ -6,16 +8,12 @@ import sys
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import Dict
from typing import final from typing import final
from typing import List from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import NoReturn from typing import NoReturn
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Tuple
from typing import Union
import _pytest._io import _pytest._io
from _pytest.config.exceptions import UsageError from _pytest.config.exceptions import UsageError
@ -41,32 +39,32 @@ class Parser:
there's an error processing the command line arguments. there's an error processing the command line arguments.
""" """
prog: Optional[str] = None prog: str | None = None
def __init__( def __init__(
self, self,
usage: Optional[str] = None, usage: str | None = None,
processopt: Optional[Callable[["Argument"], None]] = None, processopt: Callable[[Argument], None] | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True) self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True)
self._groups: List[OptionGroup] = [] self._groups: list[OptionGroup] = []
self._processopt = processopt self._processopt = processopt
self._usage = usage self._usage = usage
self._inidict: Dict[str, Tuple[str, Optional[str], Any]] = {} self._inidict: dict[str, tuple[str, str | None, Any]] = {}
self._ininames: List[str] = [] self._ininames: list[str] = []
self.extra_info: Dict[str, Any] = {} self.extra_info: dict[str, Any] = {}
def processoption(self, option: "Argument") -> None: def processoption(self, option: Argument) -> None:
if self._processopt: if self._processopt:
if option.dest: if option.dest:
self._processopt(option) self._processopt(option)
def getgroup( def getgroup(
self, name: str, description: str = "", after: Optional[str] = None self, name: str, description: str = "", after: str | None = None
) -> "OptionGroup": ) -> OptionGroup:
"""Get (or create) a named option Group. """Get (or create) a named option Group.
:param name: Name of the option group. :param name: Name of the option group.
@ -108,8 +106,8 @@ class Parser:
def parse( def parse(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> argparse.Namespace: ) -> argparse.Namespace:
from _pytest._argcomplete import try_argcomplete from _pytest._argcomplete import try_argcomplete
@ -118,7 +116,7 @@ class Parser:
strargs = [os.fspath(x) for x in args] strargs = [os.fspath(x) for x in args]
return self.optparser.parse_args(strargs, namespace=namespace) return self.optparser.parse_args(strargs, namespace=namespace)
def _getparser(self) -> "MyOptionParser": def _getparser(self) -> MyOptionParser:
from _pytest._argcomplete import filescompleter from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog) optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
@ -139,10 +137,10 @@ class Parser:
def parse_setoption( def parse_setoption(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
option: argparse.Namespace, option: argparse.Namespace,
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> List[str]: ) -> list[str]:
parsedoption = self.parse(args, namespace=namespace) parsedoption = self.parse(args, namespace=namespace)
for name, value in parsedoption.__dict__.items(): for name, value in parsedoption.__dict__.items():
setattr(option, name, value) setattr(option, name, value)
@ -150,8 +148,8 @@ class Parser:
def parse_known_args( def parse_known_args(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> argparse.Namespace: ) -> argparse.Namespace:
"""Parse the known arguments at this point. """Parse the known arguments at this point.
@ -161,9 +159,9 @@ class Parser:
def parse_known_and_unknown_args( def parse_known_and_unknown_args(
self, self,
args: Sequence[Union[str, "os.PathLike[str]"]], args: Sequence[str | os.PathLike[str]],
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> Tuple[argparse.Namespace, List[str]]: ) -> tuple[argparse.Namespace, list[str]]:
"""Parse the known arguments at this point, and also return the """Parse the known arguments at this point, and also return the
remaining unknown arguments. remaining unknown arguments.
@ -179,9 +177,8 @@ class Parser:
self, self,
name: str, name: str,
help: str, help: str,
type: Optional[ type: Literal["string", "paths", "pathlist", "args", "linelist", "bool"]
Literal["string", "paths", "pathlist", "args", "linelist", "bool"] | None = None,
] = None,
default: Any = NOT_SET, default: Any = NOT_SET,
) -> None: ) -> None:
"""Register an ini-file option. """Register an ini-file option.
@ -224,7 +221,7 @@ class Parser:
def get_ini_default_for_type( def get_ini_default_for_type(
type: Optional[Literal["string", "paths", "pathlist", "args", "linelist", "bool"]], type: Literal["string", "paths", "pathlist", "args", "linelist", "bool"] | None,
) -> Any: ) -> Any:
""" """
Used by addini to get the default value for a given ini-option type, when Used by addini to get the default value for a given ini-option type, when
@ -244,7 +241,7 @@ class ArgumentError(Exception):
"""Raised if an Argument instance is created with invalid or """Raised if an Argument instance is created with invalid or
inconsistent arguments.""" inconsistent arguments."""
def __init__(self, msg: str, option: Union["Argument", str]) -> None: def __init__(self, msg: str, option: Argument | str) -> None:
self.msg = msg self.msg = msg
self.option_id = str(option) self.option_id = str(option)
@ -267,8 +264,8 @@ class Argument:
def __init__(self, *names: str, **attrs: Any) -> None: def __init__(self, *names: str, **attrs: Any) -> None:
"""Store params in private vars for use in add_argument.""" """Store params in private vars for use in add_argument."""
self._attrs = attrs self._attrs = attrs
self._short_opts: List[str] = [] self._short_opts: list[str] = []
self._long_opts: List[str] = [] self._long_opts: list[str] = []
try: try:
self.type = attrs["type"] self.type = attrs["type"]
except KeyError: except KeyError:
@ -279,7 +276,7 @@ class Argument:
except KeyError: except KeyError:
pass pass
self._set_opt_strings(names) self._set_opt_strings(names)
dest: Optional[str] = attrs.get("dest") dest: str | None = attrs.get("dest")
if dest: if dest:
self.dest = dest self.dest = dest
elif self._long_opts: elif self._long_opts:
@ -291,7 +288,7 @@ class Argument:
self.dest = "???" # Needed for the error repr. self.dest = "???" # Needed for the error repr.
raise ArgumentError("need a long or short option", self) from e raise ArgumentError("need a long or short option", self) from e
def names(self) -> List[str]: def names(self) -> list[str]:
return self._short_opts + self._long_opts return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]: def attrs(self) -> Mapping[str, Any]:
@ -335,7 +332,7 @@ class Argument:
self._long_opts.append(opt) self._long_opts.append(opt)
def __repr__(self) -> str: def __repr__(self) -> str:
args: List[str] = [] args: list[str] = []
if self._short_opts: if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)] args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts: if self._long_opts:
@ -355,14 +352,14 @@ class OptionGroup:
self, self,
name: str, name: str,
description: str = "", description: str = "",
parser: Optional[Parser] = None, parser: Parser | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self.name = name self.name = name
self.description = description self.description = description
self.options: List[Argument] = [] self.options: list[Argument] = []
self.parser = parser self.parser = parser
def addoption(self, *opts: str, **attrs: Any) -> None: def addoption(self, *opts: str, **attrs: Any) -> None:
@ -391,7 +388,7 @@ class OptionGroup:
option = Argument(*opts, **attrs) option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=True) self._addoption_instance(option, shortupper=True)
def _addoption_instance(self, option: "Argument", shortupper: bool = False) -> None: def _addoption_instance(self, option: Argument, shortupper: bool = False) -> None:
if not shortupper: if not shortupper:
for opt in option._short_opts: for opt in option._short_opts:
if opt[0] == "-" and opt[1].islower(): if opt[0] == "-" and opt[1].islower():
@ -405,8 +402,8 @@ class MyOptionParser(argparse.ArgumentParser):
def __init__( def __init__(
self, self,
parser: Parser, parser: Parser,
extra_info: Optional[Dict[str, Any]] = None, extra_info: dict[str, Any] | None = None,
prog: Optional[str] = None, prog: str | None = None,
) -> None: ) -> None:
self._parser = parser self._parser = parser
super().__init__( super().__init__(
@ -433,8 +430,8 @@ class MyOptionParser(argparse.ArgumentParser):
# Type ignored because typeshed has a very complex type in the superclass. # Type ignored because typeshed has a very complex type in the superclass.
def parse_args( # type: ignore def parse_args( # type: ignore
self, self,
args: Optional[Sequence[str]] = None, args: Sequence[str] | None = None,
namespace: Optional[argparse.Namespace] = None, namespace: argparse.Namespace | None = None,
) -> argparse.Namespace: ) -> argparse.Namespace:
"""Allow splitting of positional arguments.""" """Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace) parsed, unrecognized = self.parse_known_args(args, namespace)
@ -453,7 +450,7 @@ class MyOptionParser(argparse.ArgumentParser):
# disable long --argument abbreviations without breaking short flags. # disable long --argument abbreviations without breaking short flags.
def _parse_optional( def _parse_optional(
self, arg_string: str self, arg_string: str
) -> Optional[Tuple[Optional[argparse.Action], str, Optional[str]]]: ) -> tuple[argparse.Action | None, str, str | None] | None:
if not arg_string: if not arg_string:
return None return None
if arg_string[0] not in self.prefix_chars: if arg_string[0] not in self.prefix_chars:
@ -505,7 +502,7 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
orgstr = super()._format_action_invocation(action) orgstr = super()._format_action_invocation(action)
if orgstr and orgstr[0] != "-": # only optional arguments if orgstr and orgstr[0] != "-": # only optional arguments
return orgstr return orgstr
res: Optional[str] = getattr(action, "_formatted_action_invocation", None) res: str | None = getattr(action, "_formatted_action_invocation", None)
if res: if res:
return res return res
options = orgstr.split(", ") options = orgstr.split(", ")
@ -514,7 +511,7 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
action._formatted_action_invocation = orgstr # type: ignore action._formatted_action_invocation = orgstr # type: ignore
return orgstr return orgstr
return_list = [] return_list = []
short_long: Dict[str, str] = {} short_long: dict[str, str] = {}
for option in options: for option in options:
if len(option) == 2 or option[2] == " ": if len(option) == 2 or option[2] == " ":
continue continue

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import final from typing import final

View File

@ -1,13 +1,10 @@
from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
import sys import sys
from typing import Dict
from typing import Iterable from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Tuple
from typing import Union
import iniconfig import iniconfig
@ -32,7 +29,7 @@ def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
def load_config_dict_from_file( def load_config_dict_from_file(
filepath: Path, filepath: Path,
) -> Optional[Dict[str, Union[str, List[str]]]]: ) -> dict[str, str | list[str]] | None:
"""Load pytest configuration from the given file path, if supported. """Load pytest configuration from the given file path, if supported.
Return None if the file does not contain valid pytest configuration. Return None if the file does not contain valid pytest configuration.
@ -77,7 +74,7 @@ def load_config_dict_from_file(
# TOML supports richer data types than ini files (strings, arrays, floats, ints, etc), # TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
# however we need to convert all scalar values to str for compatibility with the rest # however we need to convert all scalar values to str for compatibility with the rest
# of the configuration system, which expects strings only. # of the configuration system, which expects strings only.
def make_scalar(v: object) -> Union[str, List[str]]: def make_scalar(v: object) -> str | list[str]:
return v if isinstance(v, list) else str(v) return v if isinstance(v, list) else str(v)
return {k: make_scalar(v) for k, v in result.items()} return {k: make_scalar(v) for k, v in result.items()}
@ -88,7 +85,7 @@ def load_config_dict_from_file(
def locate_config( def locate_config(
invocation_dir: Path, invocation_dir: Path,
args: Iterable[Path], args: Iterable[Path],
) -> Tuple[Optional[Path], Optional[Path], Dict[str, Union[str, List[str]]]]: ) -> tuple[Path | None, Path | None, dict[str, str | list[str]]]:
"""Search in the list of arguments for a valid ini-file for pytest, """Search in the list of arguments for a valid ini-file for pytest,
and return a tuple of (rootdir, inifile, cfg-dict).""" and return a tuple of (rootdir, inifile, cfg-dict)."""
config_names = [ config_names = [
@ -101,7 +98,7 @@ def locate_config(
args = [x for x in args if not str(x).startswith("-")] args = [x for x in args if not str(x).startswith("-")]
if not args: if not args:
args = [invocation_dir] args = [invocation_dir]
found_pyproject_toml: Optional[Path] = None found_pyproject_toml: Path | None = None
for arg in args: for arg in args:
argpath = absolutepath(arg) argpath = absolutepath(arg)
for base in (argpath, *argpath.parents): for base in (argpath, *argpath.parents):
@ -122,7 +119,7 @@ def get_common_ancestor(
invocation_dir: Path, invocation_dir: Path,
paths: Iterable[Path], paths: Iterable[Path],
) -> Path: ) -> Path:
common_ancestor: Optional[Path] = None common_ancestor: Path | None = None
for path in paths: for path in paths:
if not path.exists(): if not path.exists():
continue continue
@ -144,7 +141,7 @@ def get_common_ancestor(
return common_ancestor return common_ancestor
def get_dirs_from_args(args: Iterable[str]) -> List[Path]: def get_dirs_from_args(args: Iterable[str]) -> list[Path]:
def is_option(x: str) -> bool: def is_option(x: str) -> bool:
return x.startswith("-") return x.startswith("-")
@ -171,11 +168,11 @@ CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supporte
def determine_setup( def determine_setup(
*, *,
inifile: Optional[str], inifile: str | None,
args: Sequence[str], args: Sequence[str],
rootdir_cmd_arg: Optional[str], rootdir_cmd_arg: str | None,
invocation_dir: Path, invocation_dir: Path,
) -> Tuple[Path, Optional[Path], Dict[str, Union[str, List[str]]]]: ) -> tuple[Path, Path | None, dict[str, str | list[str]]]:
"""Determine the rootdir, inifile and ini configuration values from the """Determine the rootdir, inifile and ini configuration values from the
command line arguments. command line arguments.
@ -192,7 +189,7 @@ def determine_setup(
dirs = get_dirs_from_args(args) dirs = get_dirs_from_args(args)
if inifile: if inifile:
inipath_ = absolutepath(inifile) inipath_ = absolutepath(inifile)
inipath: Optional[Path] = inipath_ inipath: Path | None = inipath_
inicfg = load_config_dict_from_file(inipath_) or {} inicfg = load_config_dict_from_file(inipath_) or {}
if rootdir_cmd_arg is None: if rootdir_cmd_arg is None:
rootdir = inipath_.parent rootdir = inipath_.parent

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Interactive debugging with PDB, the Python Debugger.""" """Interactive debugging with PDB, the Python Debugger."""
from __future__ import annotations
import argparse import argparse
import functools import functools
import sys import sys
@ -8,12 +10,7 @@ import types
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import unittest import unittest
from _pytest import outcomes from _pytest import outcomes
@ -33,7 +30,7 @@ if TYPE_CHECKING:
from _pytest.runner import CallInfo from _pytest.runner import CallInfo
def _validate_usepdb_cls(value: str) -> Tuple[str, str]: def _validate_usepdb_cls(value: str) -> tuple[str, str]:
"""Validate syntax of --pdbcls option.""" """Validate syntax of --pdbcls option."""
try: try:
modname, classname = value.split(":") modname, classname = value.split(":")
@ -98,22 +95,22 @@ def pytest_configure(config: Config) -> None:
class pytestPDB: class pytestPDB:
"""Pseudo PDB that defers to the real pdb.""" """Pseudo PDB that defers to the real pdb."""
_pluginmanager: Optional[PytestPluginManager] = None _pluginmanager: PytestPluginManager | None = None
_config: Optional[Config] = None _config: Config | None = None
_saved: List[ _saved: list[
Tuple[Callable[..., None], Optional[PytestPluginManager], Optional[Config]] tuple[Callable[..., None], PytestPluginManager | None, Config | None]
] = [] ] = []
_recursive_debug = 0 _recursive_debug = 0
_wrapped_pdb_cls: Optional[Tuple[Type[Any], Type[Any]]] = None _wrapped_pdb_cls: tuple[type[Any], type[Any]] | None = None
@classmethod @classmethod
def _is_capturing(cls, capman: Optional["CaptureManager"]) -> Union[str, bool]: def _is_capturing(cls, capman: CaptureManager | None) -> str | bool:
if capman: if capman:
return capman.is_capturing() return capman.is_capturing()
return False return False
@classmethod @classmethod
def _import_pdb_cls(cls, capman: Optional["CaptureManager"]): def _import_pdb_cls(cls, capman: CaptureManager | None):
if not cls._config: if not cls._config:
import pdb import pdb
@ -152,7 +149,7 @@ class pytestPDB:
return wrapped_cls return wrapped_cls
@classmethod @classmethod
def _get_pdb_wrapper_class(cls, pdb_cls, capman: Optional["CaptureManager"]): def _get_pdb_wrapper_class(cls, pdb_cls, capman: CaptureManager | None):
import _pytest.config import _pytest.config
class PytestPdbWrapper(pdb_cls): class PytestPdbWrapper(pdb_cls):
@ -242,7 +239,7 @@ class pytestPDB:
import _pytest.config import _pytest.config
if cls._pluginmanager is None: if cls._pluginmanager is None:
capman: Optional[CaptureManager] = None capman: CaptureManager | None = None
else: else:
capman = cls._pluginmanager.getplugin("capturemanager") capman = cls._pluginmanager.getplugin("capturemanager")
if capman: if capman:
@ -285,7 +282,7 @@ class pytestPDB:
class PdbInvoke: class PdbInvoke:
def pytest_exception_interact( def pytest_exception_interact(
self, node: Node, call: "CallInfo[Any]", report: BaseReport self, node: Node, call: CallInfo[Any], report: BaseReport
) -> None: ) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager") capman = node.config.pluginmanager.getplugin("capturemanager")
if capman: if capman:

View File

@ -9,6 +9,8 @@ All constants defined in this module should be either instances of
in case of warnings which need to format their messages. in case of warnings which need to format their messages.
""" """
from __future__ import annotations
from warnings import warn from warnings import warn
from _pytest.warning_types import PytestDeprecationWarning from _pytest.warning_types import PytestDeprecationWarning

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Discover and run doctests in modules and test files.""" """Discover and run doctests in modules and test files."""
from __future__ import annotations
import bdb import bdb
from contextlib import contextmanager from contextlib import contextmanager
import functools import functools
@ -13,17 +15,11 @@ import traceback
import types import types
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict
from typing import Generator from typing import Generator
from typing import Iterable from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Sequence from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import warnings import warnings
from _pytest import outcomes from _pytest import outcomes
@ -67,7 +63,7 @@ DOCTEST_REPORT_CHOICES = (
# Lazy definition of runner class # Lazy definition of runner class
RUNNER_CLASS = None RUNNER_CLASS = None
# Lazy definition of output checker class # Lazy definition of output checker class
CHECKER_CLASS: Optional[Type["doctest.OutputChecker"]] = None CHECKER_CLASS: type[doctest.OutputChecker] | None = None
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
@ -129,7 +125,7 @@ def pytest_unconfigure() -> None:
def pytest_collect_file( def pytest_collect_file(
file_path: Path, file_path: Path,
parent: Collector, parent: Collector,
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]: ) -> DoctestModule | DoctestTextfile | None:
config = parent.config config = parent.config
if file_path.suffix == ".py": if file_path.suffix == ".py":
if config.option.doctestmodules and not any( if config.option.doctestmodules and not any(
@ -161,7 +157,7 @@ def _is_main_py(path: Path) -> bool:
class ReprFailDoctest(TerminalRepr): class ReprFailDoctest(TerminalRepr):
def __init__( def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]] self, reprlocation_lines: Sequence[tuple[ReprFileLocation, Sequence[str]]]
) -> None: ) -> None:
self.reprlocation_lines = reprlocation_lines self.reprlocation_lines = reprlocation_lines
@ -173,12 +169,12 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception): class MultipleDoctestFailures(Exception):
def __init__(self, failures: Sequence["doctest.DocTestFailure"]) -> None: def __init__(self, failures: Sequence[doctest.DocTestFailure]) -> None:
super().__init__() super().__init__()
self.failures = failures self.failures = failures
def _init_runner_class() -> Type["doctest.DocTestRunner"]: def _init_runner_class() -> type[doctest.DocTestRunner]:
import doctest import doctest
class PytestDoctestRunner(doctest.DebugRunner): class PytestDoctestRunner(doctest.DebugRunner):
@ -190,8 +186,8 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def __init__( def __init__(
self, self,
checker: Optional["doctest.OutputChecker"] = None, checker: doctest.OutputChecker | None = None,
verbose: Optional[bool] = None, verbose: bool | None = None,
optionflags: int = 0, optionflags: int = 0,
continue_on_failure: bool = True, continue_on_failure: bool = True,
) -> None: ) -> None:
@ -201,8 +197,8 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def report_failure( def report_failure(
self, self,
out, out,
test: "doctest.DocTest", test: doctest.DocTest,
example: "doctest.Example", example: doctest.Example,
got: str, got: str,
) -> None: ) -> None:
failure = doctest.DocTestFailure(test, example, got) failure = doctest.DocTestFailure(test, example, got)
@ -214,9 +210,9 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def report_unexpected_exception( def report_unexpected_exception(
self, self,
out, out,
test: "doctest.DocTest", test: doctest.DocTest,
example: "doctest.Example", example: doctest.Example,
exc_info: Tuple[Type[BaseException], BaseException, types.TracebackType], exc_info: tuple[type[BaseException], BaseException, types.TracebackType],
) -> None: ) -> None:
if isinstance(exc_info[1], OutcomeException): if isinstance(exc_info[1], OutcomeException):
raise exc_info[1] raise exc_info[1]
@ -232,11 +228,11 @@ def _init_runner_class() -> Type["doctest.DocTestRunner"]:
def _get_runner( def _get_runner(
checker: Optional["doctest.OutputChecker"] = None, checker: doctest.OutputChecker | None = None,
verbose: Optional[bool] = None, verbose: bool | None = None,
optionflags: int = 0, optionflags: int = 0,
continue_on_failure: bool = True, continue_on_failure: bool = True,
) -> "doctest.DocTestRunner": ) -> doctest.DocTestRunner:
# We need this in order to do a lazy import on doctest # We need this in order to do a lazy import on doctest
global RUNNER_CLASS global RUNNER_CLASS
if RUNNER_CLASS is None: if RUNNER_CLASS is None:
@ -255,9 +251,9 @@ class DoctestItem(Item):
def __init__( def __init__(
self, self,
name: str, name: str,
parent: "Union[DoctestTextfile, DoctestModule]", parent: DoctestTextfile | DoctestModule,
runner: "doctest.DocTestRunner", runner: doctest.DocTestRunner,
dtest: "doctest.DocTest", dtest: doctest.DocTest,
) -> None: ) -> None:
super().__init__(name, parent) super().__init__(name, parent)
self.runner = runner self.runner = runner
@ -274,18 +270,18 @@ class DoctestItem(Item):
@classmethod @classmethod
def from_parent( # type: ignore[override] def from_parent( # type: ignore[override]
cls, cls,
parent: "Union[DoctestTextfile, DoctestModule]", parent: DoctestTextfile | DoctestModule,
*, *,
name: str, name: str,
runner: "doctest.DocTestRunner", runner: doctest.DocTestRunner,
dtest: "doctest.DocTest", dtest: doctest.DocTest,
) -> "Self": ) -> Self:
# incompatible signature due to imposed limits on subclass # incompatible signature due to imposed limits on subclass
"""The public named constructor.""" """The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest) return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def _initrequest(self) -> None: def _initrequest(self) -> None:
self.funcargs: Dict[str, object] = {} self.funcargs: dict[str, object] = {}
self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type] self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type]
def setup(self) -> None: def setup(self) -> None:
@ -298,7 +294,7 @@ class DoctestItem(Item):
def runtest(self) -> None: def runtest(self) -> None:
_check_all_skipped(self.dtest) _check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin() self._disable_output_capturing_for_darwin()
failures: List["doctest.DocTestFailure"] = [] failures: list[doctest.DocTestFailure] = []
# Type ignored because we change the type of `out` from what # Type ignored because we change the type of `out` from what
# doctest expects. # doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type] self.runner.run(self.dtest, out=failures) # type: ignore[arg-type]
@ -320,12 +316,12 @@ class DoctestItem(Item):
def repr_failure( # type: ignore[override] def repr_failure( # type: ignore[override]
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
import doctest import doctest
failures: Optional[ failures: (
Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]] Sequence[doctest.DocTestFailure | doctest.UnexpectedException] | None
] = None ) = None
if isinstance( if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException) excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
): ):
@ -381,11 +377,11 @@ class DoctestItem(Item):
reprlocation_lines.append((reprlocation, lines)) reprlocation_lines.append((reprlocation, lines))
return ReprFailDoctest(reprlocation_lines) return ReprFailDoctest(reprlocation_lines)
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]: def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
return self.path, self.dtest.lineno, "[doctest] %s" % self.name return self.path, self.dtest.lineno, "[doctest] %s" % self.name
def _get_flag_lookup() -> Dict[str, int]: def _get_flag_lookup() -> dict[str, int]:
import doctest import doctest
return dict( return dict(
@ -451,7 +447,7 @@ class DoctestTextfile(Module):
) )
def _check_all_skipped(test: "doctest.DocTest") -> None: def _check_all_skipped(test: doctest.DocTest) -> None:
"""Raise pytest.skip() if all examples in the given DocTest have the SKIP """Raise pytest.skip() if all examples in the given DocTest have the SKIP
option set.""" option set."""
import doctest import doctest
@ -477,7 +473,7 @@ def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
real_unwrap = inspect.unwrap real_unwrap = inspect.unwrap
def _mock_aware_unwrap( def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None func: Callable[..., Any], *, stop: Callable[[Any], Any] | None = None
) -> Any: ) -> Any:
try: try:
if stop is None or stop is _is_mocked: if stop is None or stop is _is_mocked:
@ -588,7 +584,7 @@ class DoctestModule(Module):
) )
def _init_checker_class() -> Type["doctest.OutputChecker"]: def _init_checker_class() -> type[doctest.OutputChecker]:
import doctest import doctest
import re import re
@ -656,8 +652,8 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return got return got
offset = 0 offset = 0
for w, g in zip(wants, gots): for w, g in zip(wants, gots):
fraction: Optional[str] = w.group("fraction") fraction: str | None = w.group("fraction")
exponent: Optional[str] = w.group("exponent1") exponent: str | None = w.group("exponent1")
if exponent is None: if exponent is None:
exponent = w.group("exponent2") exponent = w.group("exponent2")
precision = 0 if fraction is None else len(fraction) precision = 0 if fraction is None else len(fraction)
@ -676,7 +672,7 @@ def _init_checker_class() -> Type["doctest.OutputChecker"]:
return LiteralsOutputChecker return LiteralsOutputChecker
def _get_checker() -> "doctest.OutputChecker": def _get_checker() -> doctest.OutputChecker:
"""Return a doctest.OutputChecker subclass that supports some """Return a doctest.OutputChecker subclass that supports some
additional options: additional options:
@ -735,7 +731,7 @@ def _get_report_choice(key: str) -> int:
@fixture(scope="session") @fixture(scope="session")
def doctest_namespace() -> Dict[str, Any]: def doctest_namespace() -> dict[str, Any]:
"""Fixture that returns a :py:class:`dict` that will be injected into the """Fixture that returns a :py:class:`dict` that will be injected into the
namespace of doctests. namespace of doctests.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
import sys import sys
from typing import Generator from typing import Generator

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import abc import abc
from collections import defaultdict from collections import defaultdict
from collections import deque from collections import deque
@ -12,21 +14,18 @@ from typing import AbstractSet
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import Dict
from typing import Final from typing import Final
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import Generic from typing import Generic
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import MutableMapping from typing import MutableMapping
from typing import NoReturn from typing import NoReturn
from typing import Optional from typing import Optional
from typing import OrderedDict from typing import OrderedDict
from typing import overload from typing import overload
from typing import Sequence from typing import Sequence
from typing import Set
from typing import Tuple from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
@ -111,18 +110,18 @@ _FixtureCachedResult = Union[
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class PseudoFixtureDef(Generic[FixtureValue]): class PseudoFixtureDef(Generic[FixtureValue]):
cached_result: "_FixtureCachedResult[FixtureValue]" cached_result: _FixtureCachedResult[FixtureValue]
_scope: Scope _scope: Scope
def pytest_sessionstart(session: "Session") -> None: def pytest_sessionstart(session: Session) -> None:
session._fixturemanager = FixtureManager(session) session._fixturemanager = FixtureManager(session)
def get_scope_package( def get_scope_package(
node: nodes.Item, node: nodes.Item,
fixturedef: "FixtureDef[object]", fixturedef: FixtureDef[object],
) -> Optional[nodes.Node]: ) -> nodes.Node | None:
from _pytest.python import Package from _pytest.python import Package
for parent in node.iter_parents(): for parent in node.iter_parents():
@ -131,7 +130,7 @@ def get_scope_package(
return node.session return node.session
def get_scope_node(node: nodes.Node, scope: Scope) -> Optional[nodes.Node]: def get_scope_node(node: nodes.Node, scope: Scope) -> nodes.Node | None:
import _pytest.python import _pytest.python
if scope is Scope.Function: if scope is Scope.Function:
@ -150,7 +149,7 @@ def get_scope_node(node: nodes.Node, scope: Scope) -> Optional[nodes.Node]:
assert_never(scope) assert_never(scope)
def getfixturemarker(obj: object) -> Optional["FixtureFunctionMarker"]: def getfixturemarker(obj: object) -> FixtureFunctionMarker | None:
"""Return fixturemarker or None if it doesn't exist or raised """Return fixturemarker or None if it doesn't exist or raised
exceptions.""" exceptions."""
return cast( return cast(
@ -163,8 +162,8 @@ def getfixturemarker(obj: object) -> Optional["FixtureFunctionMarker"]:
class FixtureArgKey: class FixtureArgKey:
argname: str argname: str
param_index: int param_index: int
scoped_item_path: Optional[Path] scoped_item_path: Path | None
item_cls: Optional[type] item_cls: type | None
def get_parametrized_fixture_keys( def get_parametrized_fixture_keys(
@ -204,10 +203,10 @@ def get_parametrized_fixture_keys(
# setups and teardowns. # setups and teardowns.
def reorder_items(items: Sequence[nodes.Item]) -> List[nodes.Item]: def reorder_items(items: Sequence[nodes.Item]) -> list[nodes.Item]:
argkeys_cache: Dict[Scope, Dict[nodes.Item, Dict[FixtureArgKey, None]]] = {} argkeys_cache: dict[Scope, dict[nodes.Item, dict[FixtureArgKey, None]]] = {}
items_by_argkey: Dict[ items_by_argkey: dict[
Scope, Dict[FixtureArgKey, OrderedDict[nodes.Item, None]] Scope, dict[FixtureArgKey, OrderedDict[nodes.Item, None]]
] = {} ] = {}
for scope in HIGH_SCOPES: for scope in HIGH_SCOPES:
scoped_argkeys_cache = argkeys_cache[scope] = {} scoped_argkeys_cache = argkeys_cache[scope] = {}
@ -226,8 +225,8 @@ def reorder_items(items: Sequence[nodes.Item]) -> List[nodes.Item]:
def fix_cache_order( def fix_cache_order(
item: nodes.Item, item: nodes.Item,
argkeys_cache: Dict[Scope, Dict[nodes.Item, Dict[FixtureArgKey, None]]], argkeys_cache: dict[Scope, dict[nodes.Item, dict[FixtureArgKey, None]]],
items_by_argkey: Dict[Scope, Dict[FixtureArgKey, OrderedDict[nodes.Item, None]]], items_by_argkey: dict[Scope, dict[FixtureArgKey, OrderedDict[nodes.Item, None]]],
) -> None: ) -> None:
for scope in HIGH_SCOPES: for scope in HIGH_SCOPES:
scoped_items_by_argkey = items_by_argkey[scope] scoped_items_by_argkey = items_by_argkey[scope]
@ -237,20 +236,20 @@ def fix_cache_order(
def reorder_items_atscope( def reorder_items_atscope(
items: Dict[nodes.Item, None], items: dict[nodes.Item, None],
argkeys_cache: Dict[Scope, Dict[nodes.Item, Dict[FixtureArgKey, None]]], argkeys_cache: dict[Scope, dict[nodes.Item, dict[FixtureArgKey, None]]],
items_by_argkey: Dict[Scope, Dict[FixtureArgKey, OrderedDict[nodes.Item, None]]], items_by_argkey: dict[Scope, dict[FixtureArgKey, OrderedDict[nodes.Item, None]]],
scope: Scope, scope: Scope,
) -> Dict[nodes.Item, None]: ) -> dict[nodes.Item, None]:
if scope is Scope.Function or len(items) < 3: if scope is Scope.Function or len(items) < 3:
return items return items
ignore: Set[Optional[FixtureArgKey]] = set() ignore: set[FixtureArgKey | None] = set()
items_deque = deque(items) items_deque = deque(items)
items_done: Dict[nodes.Item, None] = {} items_done: dict[nodes.Item, None] = {}
scoped_items_by_argkey = items_by_argkey[scope] scoped_items_by_argkey = items_by_argkey[scope]
scoped_argkeys_cache = argkeys_cache[scope] scoped_argkeys_cache = argkeys_cache[scope]
while items_deque: while items_deque:
no_argkey_group: Dict[nodes.Item, None] = {} no_argkey_group: dict[nodes.Item, None] = {}
slicing_argkey = None slicing_argkey = None
while items_deque: while items_deque:
item = items_deque.popleft() item = items_deque.popleft()
@ -299,19 +298,19 @@ class FuncFixtureInfo:
__slots__ = ("argnames", "initialnames", "names_closure", "name2fixturedefs") __slots__ = ("argnames", "initialnames", "names_closure", "name2fixturedefs")
# Fixture names that the item requests directly by function parameters. # Fixture names that the item requests directly by function parameters.
argnames: Tuple[str, ...] argnames: tuple[str, ...]
# Fixture names that the item immediately requires. These include # Fixture names that the item immediately requires. These include
# argnames + fixture names specified via usefixtures and via autouse=True in # argnames + fixture names specified via usefixtures and via autouse=True in
# fixture definitions. # fixture definitions.
initialnames: Tuple[str, ...] initialnames: tuple[str, ...]
# The transitive closure of the fixture names that the item requires. # The transitive closure of the fixture names that the item requires.
# Note: can't include dynamic dependencies (`request.getfixturevalue` calls). # Note: can't include dynamic dependencies (`request.getfixturevalue` calls).
names_closure: List[str] names_closure: list[str]
# A map from a fixture name in the transitive closure to the FixtureDefs # A map from a fixture name in the transitive closure to the FixtureDefs
# matching the name which are applicable to this function. # matching the name which are applicable to this function.
# There may be multiple overriding fixtures with the same name. The # There may be multiple overriding fixtures with the same name. The
# sequence is ordered from furthest to closes to the function. # sequence is ordered from furthest to closes to the function.
name2fixturedefs: Dict[str, Sequence["FixtureDef[Any]"]] name2fixturedefs: dict[str, Sequence[FixtureDef[Any]]]
def prune_dependency_tree(self) -> None: def prune_dependency_tree(self) -> None:
"""Recompute names_closure from initialnames and name2fixturedefs. """Recompute names_closure from initialnames and name2fixturedefs.
@ -324,7 +323,7 @@ class FuncFixtureInfo:
tree. In this way the dependency tree can get pruned, and the closure tree. In this way the dependency tree can get pruned, and the closure
of argnames may get reduced. of argnames may get reduced.
""" """
closure: Set[str] = set() closure: set[str] = set()
working_set = set(self.initialnames) working_set = set(self.initialnames)
while working_set: while working_set:
argname = working_set.pop() argname = working_set.pop()
@ -350,10 +349,10 @@ class FixtureRequest(abc.ABC):
def __init__( def __init__(
self, self,
pyfuncitem: "Function", pyfuncitem: Function,
fixturename: Optional[str], fixturename: str | None,
arg2fixturedefs: Dict[str, Sequence["FixtureDef[Any]"]], arg2fixturedefs: dict[str, Sequence[FixtureDef[Any]]],
fixture_defs: Dict[str, "FixtureDef[Any]"], fixture_defs: dict[str, FixtureDef[Any]],
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -380,7 +379,7 @@ class FixtureRequest(abc.ABC):
self.param: Any self.param: Any
@property @property
def _fixturemanager(self) -> "FixtureManager": def _fixturemanager(self) -> FixtureManager:
return self._pyfuncitem.session._fixturemanager return self._pyfuncitem.session._fixturemanager
@property @property
@ -396,13 +395,13 @@ class FixtureRequest(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def _check_scope( def _check_scope(
self, self,
requested_fixturedef: Union["FixtureDef[object]", PseudoFixtureDef[object]], requested_fixturedef: FixtureDef[object] | PseudoFixtureDef[object],
requested_scope: Scope, requested_scope: Scope,
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
@property @property
def fixturenames(self) -> List[str]: def fixturenames(self) -> list[str]:
"""Names of all active fixtures in this request.""" """Names of all active fixtures in this request."""
result = list(self._pyfuncitem.fixturenames) result = list(self._pyfuncitem.fixturenames)
result.extend(set(self._fixture_defs).difference(result)) result.extend(set(self._fixture_defs).difference(result))
@ -467,7 +466,7 @@ class FixtureRequest(abc.ABC):
return node.keywords return node.keywords
@property @property
def session(self) -> "Session": def session(self) -> Session:
"""Pytest session object.""" """Pytest session object."""
return self._pyfuncitem.session return self._pyfuncitem.session
@ -477,7 +476,7 @@ class FixtureRequest(abc.ABC):
the last test within the requesting test context finished execution.""" the last test within the requesting test context finished execution."""
raise NotImplementedError() raise NotImplementedError()
def applymarker(self, marker: Union[str, MarkDecorator]) -> None: def applymarker(self, marker: str | MarkDecorator) -> None:
"""Apply a marker to a single test function invocation. """Apply a marker to a single test function invocation.
This method is useful if you don't want to have a keyword/marker This method is useful if you don't want to have a keyword/marker
@ -488,7 +487,7 @@ class FixtureRequest(abc.ABC):
""" """
self.node.add_marker(marker) self.node.add_marker(marker)
def raiseerror(self, msg: Optional[str]) -> NoReturn: def raiseerror(self, msg: str | None) -> NoReturn:
"""Raise a FixtureLookupError exception. """Raise a FixtureLookupError exception.
:param msg: :param msg:
@ -525,7 +524,7 @@ class FixtureRequest(abc.ABC):
) )
return fixturedef.cached_result[0] return fixturedef.cached_result[0]
def _iter_chain(self) -> Iterator["SubRequest"]: def _iter_chain(self) -> Iterator[SubRequest]:
"""Yield all SubRequests in the chain, from self up. """Yield all SubRequests in the chain, from self up.
Note: does *not* yield the TopRequest. Note: does *not* yield the TopRequest.
@ -537,7 +536,7 @@ class FixtureRequest(abc.ABC):
def _get_active_fixturedef( def _get_active_fixturedef(
self, argname: str self, argname: str
) -> Union["FixtureDef[object]", PseudoFixtureDef[object]]: ) -> FixtureDef[object] | PseudoFixtureDef[object]:
if argname == "request": if argname == "request":
cached_result = (self, [0], None) cached_result = (self, [0], None)
return PseudoFixtureDef(cached_result, Scope.Function) return PseudoFixtureDef(cached_result, Scope.Function)
@ -608,7 +607,7 @@ class FixtureRequest(abc.ABC):
self._fixture_defs[argname] = fixturedef self._fixture_defs[argname] = fixturedef
return fixturedef return fixturedef
def _check_fixturedef_without_param(self, fixturedef: "FixtureDef[object]") -> None: def _check_fixturedef_without_param(self, fixturedef: FixtureDef[object]) -> None:
"""Check that this request is allowed to execute this fixturedef without """Check that this request is allowed to execute this fixturedef without
a param.""" a param."""
funcitem = self._pyfuncitem funcitem = self._pyfuncitem
@ -641,7 +640,7 @@ class FixtureRequest(abc.ABC):
) )
fail(msg, pytrace=False) fail(msg, pytrace=False)
def _get_fixturestack(self) -> List["FixtureDef[Any]"]: def _get_fixturestack(self) -> list[FixtureDef[Any]]:
values = [request._fixturedef for request in self._iter_chain()] values = [request._fixturedef for request in self._iter_chain()]
values.reverse() values.reverse()
return values return values
@ -651,7 +650,7 @@ class FixtureRequest(abc.ABC):
class TopRequest(FixtureRequest): class TopRequest(FixtureRequest):
"""The type of the ``request`` fixture in a test function.""" """The type of the ``request`` fixture in a test function."""
def __init__(self, pyfuncitem: "Function", *, _ispytest: bool = False) -> None: def __init__(self, pyfuncitem: Function, *, _ispytest: bool = False) -> None:
super().__init__( super().__init__(
fixturename=None, fixturename=None,
pyfuncitem=pyfuncitem, pyfuncitem=pyfuncitem,
@ -666,7 +665,7 @@ class TopRequest(FixtureRequest):
def _check_scope( def _check_scope(
self, self,
requested_fixturedef: Union["FixtureDef[object]", PseudoFixtureDef[object]], requested_fixturedef: FixtureDef[object] | PseudoFixtureDef[object],
requested_scope: Scope, requested_scope: Scope,
) -> None: ) -> None:
# TopRequest always has function scope so always valid. # TopRequest always has function scope so always valid.
@ -700,7 +699,7 @@ class SubRequest(FixtureRequest):
scope: Scope, scope: Scope,
param: Any, param: Any,
param_index: int, param_index: int,
fixturedef: "FixtureDef[object]", fixturedef: FixtureDef[object],
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -730,7 +729,7 @@ class SubRequest(FixtureRequest):
scope = self._scope scope = self._scope
if scope is Scope.Function: if scope is Scope.Function:
# This might also be a non-function Item despite its attribute name. # This might also be a non-function Item despite its attribute name.
node: Optional[nodes.Node] = self._pyfuncitem node: nodes.Node | None = self._pyfuncitem
elif scope is Scope.Package: elif scope is Scope.Package:
node = get_scope_package(self._pyfuncitem, self._fixturedef) node = get_scope_package(self._pyfuncitem, self._fixturedef)
else: else:
@ -743,7 +742,7 @@ class SubRequest(FixtureRequest):
def _check_scope( def _check_scope(
self, self,
requested_fixturedef: Union["FixtureDef[object]", PseudoFixtureDef[object]], requested_fixturedef: FixtureDef[object] | PseudoFixtureDef[object],
requested_scope: Scope, requested_scope: Scope,
) -> None: ) -> None:
if isinstance(requested_fixturedef, PseudoFixtureDef): if isinstance(requested_fixturedef, PseudoFixtureDef):
@ -764,7 +763,7 @@ class SubRequest(FixtureRequest):
pytrace=False, pytrace=False,
) )
def _format_fixturedef_line(self, fixturedef: "FixtureDef[object]") -> str: def _format_fixturedef_line(self, fixturedef: FixtureDef[object]) -> str:
factory = fixturedef.func factory = fixturedef.func
path, lineno = getfslineno(factory) path, lineno = getfslineno(factory)
if isinstance(path, Path): if isinstance(path, Path):
@ -781,15 +780,15 @@ class FixtureLookupError(LookupError):
"""Could not return a requested fixture (missing or invalid).""" """Could not return a requested fixture (missing or invalid)."""
def __init__( def __init__(
self, argname: Optional[str], request: FixtureRequest, msg: Optional[str] = None self, argname: str | None, request: FixtureRequest, msg: str | None = None
) -> None: ) -> None:
self.argname = argname self.argname = argname
self.request = request self.request = request
self.fixturestack = request._get_fixturestack() self.fixturestack = request._get_fixturestack()
self.msg = msg self.msg = msg
def formatrepr(self) -> "FixtureLookupErrorRepr": def formatrepr(self) -> FixtureLookupErrorRepr:
tblines: List[str] = [] tblines: list[str] = []
addline = tblines.append addline = tblines.append
stack = [self.request._pyfuncitem.obj] stack = [self.request._pyfuncitem.obj]
stack.extend(map(lambda x: x.func, self.fixturestack)) stack.extend(map(lambda x: x.func, self.fixturestack))
@ -837,11 +836,11 @@ class FixtureLookupError(LookupError):
class FixtureLookupErrorRepr(TerminalRepr): class FixtureLookupErrorRepr(TerminalRepr):
def __init__( def __init__(
self, self,
filename: Union[str, "os.PathLike[str]"], filename: str | os.PathLike[str],
firstlineno: int, firstlineno: int,
tblines: Sequence[str], tblines: Sequence[str],
errorstring: str, errorstring: str,
argname: Optional[str], argname: str | None,
) -> None: ) -> None:
self.tblines = tblines self.tblines = tblines
self.errorstring = errorstring self.errorstring = errorstring
@ -869,7 +868,7 @@ class FixtureLookupErrorRepr(TerminalRepr):
def call_fixture_func( def call_fixture_func(
fixturefunc: "_FixtureFunc[FixtureValue]", request: FixtureRequest, kwargs fixturefunc: _FixtureFunc[FixtureValue], request: FixtureRequest, kwargs
) -> FixtureValue: ) -> FixtureValue:
if is_generator(fixturefunc): if is_generator(fixturefunc):
fixturefunc = cast( fixturefunc = cast(
@ -940,14 +939,12 @@ class FixtureDef(Generic[FixtureValue]):
def __init__( def __init__(
self, self,
config: Config, config: Config,
baseid: Optional[str], baseid: str | None,
argname: str, argname: str,
func: "_FixtureFunc[FixtureValue]", func: _FixtureFunc[FixtureValue],
scope: Union[Scope, _ScopeName, Callable[[str, Config], _ScopeName], None], scope: Scope | _ScopeName | Callable[[str, Config], _ScopeName] | None,
params: Optional[Sequence[object]], params: Sequence[object] | None,
ids: Optional[ ids: tuple[object | None, ...] | Callable[[Any], object | None] | None = None,
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]]
] = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -993,8 +990,8 @@ class FixtureDef(Generic[FixtureValue]):
self.argnames: Final = getfuncargnames(func, name=argname) self.argnames: Final = getfuncargnames(func, name=argname)
# If the fixture was executed, the current value of the fixture. # If the fixture was executed, the current value of the fixture.
# Can change if the fixture is executed with different parameters. # Can change if the fixture is executed with different parameters.
self.cached_result: Optional[_FixtureCachedResult[FixtureValue]] = None self.cached_result: _FixtureCachedResult[FixtureValue] | None = None
self._finalizers: Final[List[Callable[[], object]]] = [] self._finalizers: Final[list[Callable[[], object]]] = []
@property @property
def scope(self) -> _ScopeName: def scope(self) -> _ScopeName:
@ -1005,7 +1002,7 @@ class FixtureDef(Generic[FixtureValue]):
self._finalizers.append(finalizer) self._finalizers.append(finalizer)
def finish(self, request: SubRequest) -> None: def finish(self, request: SubRequest) -> None:
exceptions: List[BaseException] = [] exceptions: list[BaseException] = []
while self._finalizers: while self._finalizers:
fin = self._finalizers.pop() fin = self._finalizers.pop()
try: try:
@ -1089,7 +1086,7 @@ class FixtureDef(Generic[FixtureValue]):
def resolve_fixture_function( def resolve_fixture_function(
fixturedef: FixtureDef[FixtureValue], request: FixtureRequest fixturedef: FixtureDef[FixtureValue], request: FixtureRequest
) -> "_FixtureFunc[FixtureValue]": ) -> _FixtureFunc[FixtureValue]:
"""Get the actual callable that can be called to obtain the fixture """Get the actual callable that can be called to obtain the fixture
value.""" value."""
fixturefunc = fixturedef.func fixturefunc = fixturedef.func
@ -1137,7 +1134,7 @@ def pytest_fixture_setup(
def wrap_function_to_error_out_if_called_directly( def wrap_function_to_error_out_if_called_directly(
function: FixtureFunction, function: FixtureFunction,
fixture_marker: "FixtureFunctionMarker", fixture_marker: FixtureFunctionMarker,
) -> FixtureFunction: ) -> FixtureFunction:
"""Wrap the given fixture function so we can raise an error about it being called directly, """Wrap the given fixture function so we can raise an error about it being called directly,
instead of used as an argument in a test function.""" instead of used as an argument in a test function."""
@ -1163,13 +1160,11 @@ def wrap_function_to_error_out_if_called_directly(
@final @final
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FixtureFunctionMarker: class FixtureFunctionMarker:
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" scope: _ScopeName | Callable[[str, Config], _ScopeName]
params: Optional[Tuple[object, ...]] params: tuple[object, ...] | None
autouse: bool = False autouse: bool = False
ids: Optional[ ids: tuple[object | None, ...] | Callable[[Any], object | None] | None = None
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]] name: str | None = None
] = None
name: Optional[str] = None
_ispytest: dataclasses.InitVar[bool] = False _ispytest: dataclasses.InitVar[bool] = False
@ -1207,13 +1202,11 @@ class FixtureFunctionMarker:
def fixture( def fixture(
fixture_function: FixtureFunction, fixture_function: FixtureFunction,
*, *,
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ..., scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
params: Optional[Iterable[object]] = ..., params: Iterable[object] | None = ...,
autouse: bool = ..., autouse: bool = ...,
ids: Optional[ ids: Sequence[object | None] | Callable[[Any], object | None] | None = ...,
Union[Sequence[Optional[object]], Callable[[Any], Optional[object]]] name: str | None = ...,
] = ...,
name: Optional[str] = ...,
) -> FixtureFunction: ... ) -> FixtureFunction: ...
@ -1221,27 +1214,23 @@ def fixture(
def fixture( def fixture(
fixture_function: None = ..., fixture_function: None = ...,
*, *,
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ..., scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
params: Optional[Iterable[object]] = ..., params: Iterable[object] | None = ...,
autouse: bool = ..., autouse: bool = ...,
ids: Optional[ ids: Sequence[object | None] | Callable[[Any], object | None] | None = ...,
Union[Sequence[Optional[object]], Callable[[Any], Optional[object]]] name: str | None = None,
] = ...,
name: Optional[str] = None,
) -> FixtureFunctionMarker: ... ) -> FixtureFunctionMarker: ...
def fixture( def fixture(
fixture_function: Optional[FixtureFunction] = None, fixture_function: FixtureFunction | None = None,
*, *,
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = "function", scope: _ScopeName | Callable[[str, Config], _ScopeName] = "function",
params: Optional[Iterable[object]] = None, params: Iterable[object] | None = None,
autouse: bool = False, autouse: bool = False,
ids: Optional[ ids: Sequence[object | None] | Callable[[Any], object | None] | None = None,
Union[Sequence[Optional[object]], Callable[[Any], Optional[object]]] name: str | None = None,
] = None, ) -> FixtureFunctionMarker | FixtureFunction:
name: Optional[str] = None,
) -> Union[FixtureFunctionMarker, FixtureFunction]:
"""Decorator to mark a fixture factory function. """Decorator to mark a fixture factory function.
This decorator can be used, with or without parameters, to define a This decorator can be used, with or without parameters, to define a
@ -1375,7 +1364,7 @@ def pytest_addoption(parser: Parser) -> None:
) )
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.showfixtures: if config.option.showfixtures:
showfixtures(config) showfixtures(config)
return 0 return 0
@ -1385,7 +1374,7 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
return None return None
def _get_direct_parametrize_args(node: nodes.Node) -> Set[str]: def _get_direct_parametrize_args(node: nodes.Node) -> set[str]:
"""Return all direct parametrization arguments of a node, so we don't """Return all direct parametrization arguments of a node, so we don't
mistake them for fixtures. mistake them for fixtures.
@ -1394,7 +1383,7 @@ def _get_direct_parametrize_args(node: nodes.Node) -> Set[str]:
These things are done later as well when dealing with parametrization These things are done later as well when dealing with parametrization
so this could be improved. so this could be improved.
""" """
parametrize_argnames: Set[str] = set() parametrize_argnames: set[str] = set()
for marker in node.iter_markers(name="parametrize"): for marker in node.iter_markers(name="parametrize"):
if not marker.kwargs.get("indirect", False): if not marker.kwargs.get("indirect", False):
p_argnames, _ = ParameterSet._parse_parametrize_args( p_argnames, _ = ParameterSet._parse_parametrize_args(
@ -1404,7 +1393,7 @@ def _get_direct_parametrize_args(node: nodes.Node) -> Set[str]:
return parametrize_argnames return parametrize_argnames
def deduplicate_names(*seqs: Iterable[str]) -> Tuple[str, ...]: def deduplicate_names(*seqs: Iterable[str]) -> tuple[str, ...]:
"""De-duplicate the sequence of names while keeping the original order.""" """De-duplicate the sequence of names while keeping the original order."""
# Ideally we would use a set, but it does not preserve insertion order. # Ideally we would use a set, but it does not preserve insertion order.
return tuple(dict.fromkeys(name for seq in seqs for name in seq)) return tuple(dict.fromkeys(name for seq in seqs for name in seq))
@ -1441,17 +1430,17 @@ class FixtureManager:
by a lookup of their FuncFixtureInfo. by a lookup of their FuncFixtureInfo.
""" """
def __init__(self, session: "Session") -> None: def __init__(self, session: Session) -> None:
self.session = session self.session = session
self.config: Config = session.config self.config: Config = session.config
# Maps a fixture name (argname) to all of the FixtureDefs in the test # Maps a fixture name (argname) to all of the FixtureDefs in the test
# suite/plugins defined with this name. Populated by parsefactories(). # suite/plugins defined with this name. Populated by parsefactories().
# TODO: The order of the FixtureDefs list of each arg is significant, # TODO: The order of the FixtureDefs list of each arg is significant,
# explain. # explain.
self._arg2fixturedefs: Final[Dict[str, List[FixtureDef[Any]]]] = {} self._arg2fixturedefs: Final[dict[str, list[FixtureDef[Any]]]] = {}
self._holderobjseen: Final[Set[object]] = set() self._holderobjseen: Final[set[object]] = set()
# A mapping from a nodeid to a list of autouse fixtures it defines. # A mapping from a nodeid to a list of autouse fixtures it defines.
self._nodeid_autousenames: Final[Dict[str, List[str]]] = { self._nodeid_autousenames: Final[dict[str, list[str]]] = {
"": self.config.getini("usefixtures"), "": self.config.getini("usefixtures"),
} }
session.config.pluginmanager.register(self, "funcmanage") session.config.pluginmanager.register(self, "funcmanage")
@ -1459,8 +1448,8 @@ class FixtureManager:
def getfixtureinfo( def getfixtureinfo(
self, self,
node: nodes.Item, node: nodes.Item,
func: Optional[Callable[..., object]], func: Callable[..., object] | None,
cls: Optional[type], cls: type | None,
) -> FuncFixtureInfo: ) -> FuncFixtureInfo:
"""Calculate the :class:`FuncFixtureInfo` for an item. """Calculate the :class:`FuncFixtureInfo` for an item.
@ -1531,9 +1520,9 @@ class FixtureManager:
def getfixtureclosure( def getfixtureclosure(
self, self,
parentnode: nodes.Node, parentnode: nodes.Node,
initialnames: Tuple[str, ...], initialnames: tuple[str, ...],
ignore_args: AbstractSet[str], ignore_args: AbstractSet[str],
) -> Tuple[List[str], Dict[str, Sequence[FixtureDef[Any]]]]: ) -> tuple[list[str], dict[str, Sequence[FixtureDef[Any]]]]:
# Collect the closure of all fixtures, starting with the given # Collect the closure of all fixtures, starting with the given
# fixturenames as the initial set. As we have to visit all # fixturenames as the initial set. As we have to visit all
# factory definitions anyway, we also return an arg2fixturedefs # factory definitions anyway, we also return an arg2fixturedefs
@ -1543,7 +1532,7 @@ class FixtureManager:
fixturenames_closure = list(initialnames) fixturenames_closure = list(initialnames)
arg2fixturedefs: Dict[str, Sequence[FixtureDef[Any]]] = {} arg2fixturedefs: dict[str, Sequence[FixtureDef[Any]]] = {}
lastlen = -1 lastlen = -1
while lastlen != len(fixturenames_closure): while lastlen != len(fixturenames_closure):
lastlen = len(fixturenames_closure) lastlen = len(fixturenames_closure)
@ -1570,7 +1559,7 @@ class FixtureManager:
fixturenames_closure.sort(key=sort_by_scope, reverse=True) fixturenames_closure.sort(key=sort_by_scope, reverse=True)
return fixturenames_closure, arg2fixturedefs return fixturenames_closure, arg2fixturedefs
def pytest_generate_tests(self, metafunc: "Metafunc") -> None: def pytest_generate_tests(self, metafunc: Metafunc) -> None:
"""Generate new tests based on parametrized fixtures used by the given metafunc""" """Generate new tests based on parametrized fixtures used by the given metafunc"""
def get_parametrize_mark_argnames(mark: Mark) -> Sequence[str]: def get_parametrize_mark_argnames(mark: Mark) -> Sequence[str]:
@ -1615,7 +1604,7 @@ class FixtureManager:
# Try next super fixture, if any. # Try next super fixture, if any.
def pytest_collection_modifyitems(self, items: List[nodes.Item]) -> None: def pytest_collection_modifyitems(self, items: list[nodes.Item]) -> None:
# Separate parametrized setups. # Separate parametrized setups.
items[:] = reorder_items(items) items[:] = reorder_items(items)
@ -1623,15 +1612,14 @@ class FixtureManager:
self, self,
*, *,
name: str, name: str,
func: "_FixtureFunc[object]", func: _FixtureFunc[object],
nodeid: Optional[str], nodeid: str | None,
scope: Union[ scope: Scope
Scope, _ScopeName, Callable[[str, Config], _ScopeName], None | _ScopeName
] = "function", | Callable[[str, Config], _ScopeName]
params: Optional[Sequence[object]] = None, | None = "function",
ids: Optional[ params: Sequence[object] | None = None,
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]] ids: tuple[object | None, ...] | Callable[[Any], object | None] | None = None,
] = None,
autouse: bool = False, autouse: bool = False,
) -> None: ) -> None:
"""Register a fixture """Register a fixture
@ -1689,14 +1677,14 @@ class FixtureManager:
def parsefactories( def parsefactories(
self, self,
node_or_obj: object, node_or_obj: object,
nodeid: Optional[str], nodeid: str | None,
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
def parsefactories( def parsefactories(
self, self,
node_or_obj: Union[nodes.Node, object], node_or_obj: nodes.Node | object,
nodeid: Union[str, NotSetType, None] = NOTSET, nodeid: str | NotSetType | None = NOTSET,
) -> None: ) -> None:
"""Collect fixtures from a collection node or object. """Collect fixtures from a collection node or object.
@ -1754,7 +1742,7 @@ class FixtureManager:
def getfixturedefs( def getfixturedefs(
self, argname: str, node: nodes.Node self, argname: str, node: nodes.Node
) -> Optional[Sequence[FixtureDef[Any]]]: ) -> Sequence[FixtureDef[Any]] | None:
"""Get FixtureDefs for a fixture name which are applicable """Get FixtureDefs for a fixture name which are applicable
to a given node. to a given node.
@ -1781,7 +1769,7 @@ class FixtureManager:
yield fixturedef yield fixturedef
def show_fixtures_per_test(config: Config) -> Union[int, ExitCode]: def show_fixtures_per_test(config: Config) -> int | ExitCode:
from _pytest.main import wrap_session from _pytest.main import wrap_session
return wrap_session(config, _show_fixtures_per_test) return wrap_session(config, _show_fixtures_per_test)
@ -1799,7 +1787,7 @@ def _pretty_fixture_path(invocation_dir: Path, func) -> str:
return bestrelpath(invocation_dir, loc) return bestrelpath(invocation_dir, loc)
def _show_fixtures_per_test(config: Config, session: "Session") -> None: def _show_fixtures_per_test(config: Config, session: Session) -> None:
import _pytest.config import _pytest.config
session.perform_collect() session.perform_collect()
@ -1829,7 +1817,7 @@ def _show_fixtures_per_test(config: Config, session: "Session") -> None:
def write_item(item: nodes.Item) -> None: def write_item(item: nodes.Item) -> None:
# Not all items have _fixtureinfo attribute. # Not all items have _fixtureinfo attribute.
info: Optional[FuncFixtureInfo] = getattr(item, "_fixtureinfo", None) info: FuncFixtureInfo | None = getattr(item, "_fixtureinfo", None)
if info is None or not info.name2fixturedefs: if info is None or not info.name2fixturedefs:
# This test item does not use any fixtures. # This test item does not use any fixtures.
return return
@ -1849,13 +1837,13 @@ def _show_fixtures_per_test(config: Config, session: "Session") -> None:
write_item(session_item) write_item(session_item)
def showfixtures(config: Config) -> Union[int, ExitCode]: def showfixtures(config: Config) -> int | ExitCode:
from _pytest.main import wrap_session from _pytest.main import wrap_session
return wrap_session(config, _showfixtures_main) return wrap_session(config, _showfixtures_main)
def _showfixtures_main(config: Config, session: "Session") -> None: def _showfixtures_main(config: Config, session: Session) -> None:
import _pytest.config import _pytest.config
session.perform_collect() session.perform_collect()
@ -1866,7 +1854,7 @@ def _showfixtures_main(config: Config, session: "Session") -> None:
fm = session._fixturemanager fm = session._fixturemanager
available = [] available = []
seen: Set[Tuple[str, str]] = set() seen: set[tuple[str, str]] = set()
for argname, fixturedefs in fm._arg2fixturedefs.items(): for argname, fixturedefs in fm._arg2fixturedefs.items():
assert fixturedefs is not None assert fixturedefs is not None

View File

@ -1,13 +1,13 @@
"""Provides a function to report all internal modules for using freezing """Provides a function to report all internal modules for using freezing
tools.""" tools."""
from __future__ import annotations
import types import types
from typing import Iterator from typing import Iterator
from typing import List
from typing import Union
def freeze_includes() -> List[str]: def freeze_includes() -> list[str]:
"""Return a list of module names used by pytest that should be """Return a list of module names used by pytest that should be
included by cx_freeze.""" included by cx_freeze."""
import _pytest import _pytest
@ -17,7 +17,7 @@ def freeze_includes() -> List[str]:
def _iter_all_modules( def _iter_all_modules(
package: Union[str, types.ModuleType], package: str | types.ModuleType,
prefix: str = "", prefix: str = "",
) -> Iterator[str]: ) -> Iterator[str]:
"""Iterate over the names of all modules that can be found in the given """Iterate over the names of all modules that can be found in the given

View File

@ -1,13 +1,12 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Version info, help messages, tracing configuration.""" """Version info, help messages, tracing configuration."""
from __future__ import annotations
from argparse import Action from argparse import Action
import os import os
import sys import sys
from typing import Generator from typing import Generator
from typing import List
from typing import Optional
from typing import Union
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
@ -147,7 +146,7 @@ def showversion(config: Config) -> None:
sys.stdout.write(f"pytest {pytest.__version__}\n") sys.stdout.write(f"pytest {pytest.__version__}\n")
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.version > 0: if config.option.version > 0:
showversion(config) showversion(config)
return 0 return 0
@ -162,7 +161,7 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
def showhelp(config: Config) -> None: def showhelp(config: Config) -> None:
import textwrap import textwrap
reporter: Optional[TerminalReporter] = config.pluginmanager.get_plugin( reporter: TerminalReporter | None = config.pluginmanager.get_plugin(
"terminalreporter" "terminalreporter"
) )
assert reporter is not None assert reporter is not None
@ -239,7 +238,7 @@ def showhelp(config: Config) -> None:
conftest_options = [("pytest_plugins", "list of plugin names to load")] conftest_options = [("pytest_plugins", "list of plugin names to load")]
def getpluginversioninfo(config: Config) -> List[str]: def getpluginversioninfo(config: Config) -> list[str]:
lines = [] lines = []
plugininfo = config.pluginmanager.list_plugin_distinfo() plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo: if plugininfo:
@ -251,7 +250,7 @@ def getpluginversioninfo(config: Config) -> List[str]:
return lines return lines
def pytest_report_header(config: Config) -> List[str]: def pytest_report_header(config: Config) -> list[str]:
lines = [] lines = []
if config.option.debug or config.option.traceconfig: if config.option.debug or config.option.traceconfig:
lines.append(f"using: pytest-{pytest.__version__}") lines.append(f"using: pytest-{pytest.__version__}")

View File

@ -2,16 +2,13 @@
"""Hook specifications for pytest plugins which are invoked by pytest itself """Hook specifications for pytest plugins which are invoked by pytest itself
and by builtin plugins.""" and by builtin plugins."""
from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Dict
from typing import List
from typing import Mapping from typing import Mapping
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
from pluggy import HookspecMarker from pluggy import HookspecMarker
@ -56,7 +53,7 @@ hookspec = HookspecMarker("pytest")
@hookspec(historic=True) @hookspec(historic=True)
def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None: def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Called at plugin registration time to allow adding new hooks via a call to """Called at plugin registration time to allow adding new hooks via a call to
:func:`pluginmanager.add_hookspecs(module_or_class, prefix) <pytest.PytestPluginManager.add_hookspecs>`. :func:`pluginmanager.add_hookspecs(module_or_class, prefix) <pytest.PytestPluginManager.add_hookspecs>`.
@ -75,9 +72,9 @@ def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
@hookspec(historic=True) @hookspec(historic=True)
def pytest_plugin_registered( def pytest_plugin_registered(
plugin: "_PluggyPlugin", plugin: _PluggyPlugin,
plugin_name: str, plugin_name: str,
manager: "PytestPluginManager", manager: PytestPluginManager,
) -> None: ) -> None:
"""A new pytest plugin got registered. """A new pytest plugin got registered.
@ -99,7 +96,7 @@ def pytest_plugin_registered(
@hookspec(historic=True) @hookspec(historic=True)
def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") -> None: def pytest_addoption(parser: Parser, pluginmanager: PytestPluginManager) -> None:
"""Register argparse-style options and ini-style config values, """Register argparse-style options and ini-style config values,
called once at the beginning of a test run. called once at the beginning of a test run.
@ -140,7 +137,7 @@ def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") ->
@hookspec(historic=True) @hookspec(historic=True)
def pytest_configure(config: "Config") -> None: def pytest_configure(config: Config) -> None:
"""Allow plugins and conftest files to perform initial configuration. """Allow plugins and conftest files to perform initial configuration.
.. note:: .. note::
@ -165,8 +162,8 @@ def pytest_configure(config: "Config") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_cmdline_parse( def pytest_cmdline_parse(
pluginmanager: "PytestPluginManager", args: List[str] pluginmanager: PytestPluginManager, args: list[str]
) -> Optional["Config"]: ) -> Config | None:
"""Return an initialized :class:`~pytest.Config`, parsing the specified args. """Return an initialized :class:`~pytest.Config`, parsing the specified args.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -188,7 +185,7 @@ def pytest_cmdline_parse(
def pytest_load_initial_conftests( def pytest_load_initial_conftests(
early_config: "Config", parser: "Parser", args: List[str] early_config: Config, parser: Parser, args: list[str]
) -> None: ) -> None:
"""Called to implement the loading of :ref:`initial conftest files """Called to implement the loading of :ref:`initial conftest files
<pluginorder>` ahead of command line option parsing. <pluginorder>` ahead of command line option parsing.
@ -205,7 +202,7 @@ def pytest_load_initial_conftests(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]: def pytest_cmdline_main(config: Config) -> ExitCode | int | None:
"""Called for performing the main command line action. """Called for performing the main command line action.
The default implementation will invoke the configure hooks and The default implementation will invoke the configure hooks and
@ -229,7 +226,7 @@ def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_collection(session: "Session") -> Optional[object]: def pytest_collection(session: Session) -> object | None:
"""Perform the collection phase for the given session. """Perform the collection phase for the given session.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -271,7 +268,7 @@ def pytest_collection(session: "Session") -> Optional[object]:
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
session: "Session", config: "Config", items: List["Item"] session: Session, config: Config, items: list[Item]
) -> None: ) -> None:
"""Called after collection has been performed. May filter or re-order """Called after collection has been performed. May filter or re-order
the items in-place. the items in-place.
@ -287,7 +284,7 @@ def pytest_collection_modifyitems(
""" """
def pytest_collection_finish(session: "Session") -> None: def pytest_collection_finish(session: Session) -> None:
"""Called after collection has been performed and modified. """Called after collection has been performed and modified.
:param session: The pytest session object. :param session: The pytest session object.
@ -308,8 +305,8 @@ def pytest_collection_finish(session: "Session") -> None:
}, },
) )
def pytest_ignore_collect( def pytest_ignore_collect(
collection_path: Path, path: "LEGACY_PATH", config: "Config" collection_path: Path, path: LEGACY_PATH, config: Config
) -> Optional[bool]: ) -> bool | None:
"""Return ``True`` to ignore this path for collection. """Return ``True`` to ignore this path for collection.
Return ``None`` to let other plugins ignore the path for collection. Return ``None`` to let other plugins ignore the path for collection.
@ -323,6 +320,7 @@ def pytest_ignore_collect(
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
:param collection_path: The path to analyze. :param collection_path: The path to analyze.
:type collection_path: pathlib.Path
:param path: The path to analyze (deprecated). :param path: The path to analyze (deprecated).
:param config: The pytest config object. :param config: The pytest config object.
@ -342,7 +340,7 @@ def pytest_ignore_collect(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Collector]": def pytest_collect_directory(path: Path, parent: Collector) -> Collector | None:
"""Create a :class:`~pytest.Collector` for the given directory, or None if """Create a :class:`~pytest.Collector` for the given directory, or None if
not relevant. not relevant.
@ -356,6 +354,7 @@ def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Colle
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
:param path: The path to analyze. :param path: The path to analyze.
:type path: pathlib.Path
See :ref:`custom directory collectors` for a simple example of use of this See :ref:`custom directory collectors` for a simple example of use of this
hook. hook.
@ -378,8 +377,8 @@ def pytest_collect_directory(path: Path, parent: "Collector") -> "Optional[Colle
}, },
) )
def pytest_collect_file( def pytest_collect_file(
file_path: Path, path: "LEGACY_PATH", parent: "Collector" file_path: Path, path: LEGACY_PATH, parent: Collector
) -> "Optional[Collector]": ) -> Collector | None:
"""Create a :class:`~pytest.Collector` for the given path, or None if not relevant. """Create a :class:`~pytest.Collector` for the given path, or None if not relevant.
For best results, the returned collector should be a subclass of For best results, the returned collector should be a subclass of
@ -388,6 +387,7 @@ def pytest_collect_file(
The new node needs to have the specified ``parent`` as a parent. The new node needs to have the specified ``parent`` as a parent.
:param file_path: The path to analyze. :param file_path: The path to analyze.
:type file_path: pathlib.Path
:param path: The path to collect (deprecated). :param path: The path to collect (deprecated).
.. versionchanged:: 7.0.0 .. versionchanged:: 7.0.0
@ -406,7 +406,7 @@ def pytest_collect_file(
# logging hooks for collection # logging hooks for collection
def pytest_collectstart(collector: "Collector") -> None: def pytest_collectstart(collector: Collector) -> None:
"""Collector starts collecting. """Collector starts collecting.
:param collector: :param collector:
@ -421,7 +421,7 @@ def pytest_collectstart(collector: "Collector") -> None:
""" """
def pytest_itemcollected(item: "Item") -> None: def pytest_itemcollected(item: Item) -> None:
"""We just collected a test item. """We just collected a test item.
:param item: :param item:
@ -435,7 +435,7 @@ def pytest_itemcollected(item: "Item") -> None:
""" """
def pytest_collectreport(report: "CollectReport") -> None: def pytest_collectreport(report: CollectReport) -> None:
"""Collector finished collecting. """Collector finished collecting.
:param report: :param report:
@ -450,7 +450,7 @@ def pytest_collectreport(report: "CollectReport") -> None:
""" """
def pytest_deselected(items: Sequence["Item"]) -> None: def pytest_deselected(items: Sequence[Item]) -> None:
"""Called for deselected test items, e.g. by keyword. """Called for deselected test items, e.g. by keyword.
May be called multiple times. May be called multiple times.
@ -466,7 +466,7 @@ def pytest_deselected(items: Sequence["Item"]) -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectReport]": def pytest_make_collect_report(collector: Collector) -> CollectReport | None:
"""Perform :func:`collector.collect() <pytest.Collector.collect>` and return """Perform :func:`collector.collect() <pytest.Collector.collect>` and return
a :class:`~pytest.CollectReport`. a :class:`~pytest.CollectReport`.
@ -498,8 +498,8 @@ def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectRepor
}, },
) )
def pytest_pycollect_makemodule( def pytest_pycollect_makemodule(
module_path: Path, path: "LEGACY_PATH", parent module_path: Path, path: LEGACY_PATH, parent
) -> Optional["Module"]: ) -> Module | None:
"""Return a :class:`pytest.Module` collector or None for the given path. """Return a :class:`pytest.Module` collector or None for the given path.
This hook will be called for each matching test module path. This hook will be called for each matching test module path.
@ -509,6 +509,7 @@ def pytest_pycollect_makemodule(
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
:param module_path: The path of the module to collect. :param module_path: The path of the module to collect.
:type module_path: pathlib.Path
:param path: The path of the module to collect (deprecated). :param path: The path of the module to collect (deprecated).
.. versionchanged:: 7.0.0 .. versionchanged:: 7.0.0
@ -528,8 +529,8 @@ def pytest_pycollect_makemodule(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: Union["Module", "Class"], name: str, obj: object collector: Module | Class, name: str, obj: object
) -> Union[None, "Item", "Collector", List[Union["Item", "Collector"]]]: ) -> None | Item | Collector | list[Item | Collector]:
"""Return a custom item/collector for a Python object in a module, or None. """Return a custom item/collector for a Python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -553,7 +554,7 @@ def pytest_pycollect_makeitem(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]: def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
"""Call underlying test function. """Call underlying test function.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -570,7 +571,7 @@ def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
""" """
def pytest_generate_tests(metafunc: "Metafunc") -> None: def pytest_generate_tests(metafunc: Metafunc) -> None:
"""Generate (multiple) parametrized calls to a test function. """Generate (multiple) parametrized calls to a test function.
:param metafunc: :param metafunc:
@ -586,9 +587,7 @@ def pytest_generate_tests(metafunc: "Metafunc") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_make_parametrize_id( def pytest_make_parametrize_id(config: Config, val: object, argname: str) -> str | None:
config: "Config", val: object, argname: str
) -> Optional[str]:
"""Return a user-friendly string representation of the given ``val`` """Return a user-friendly string representation of the given ``val``
that will be used by @pytest.mark.parametrize calls, or None if the hook that will be used by @pytest.mark.parametrize calls, or None if the hook
doesn't know about ``val``. doesn't know about ``val``.
@ -614,7 +613,7 @@ def pytest_make_parametrize_id(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_runtestloop(session: "Session") -> Optional[object]: def pytest_runtestloop(session: Session) -> object | None:
"""Perform the main runtest loop (after collection finished). """Perform the main runtest loop (after collection finished).
The default hook implementation performs the runtest protocol for all items The default hook implementation performs the runtest protocol for all items
@ -640,9 +639,7 @@ def pytest_runtestloop(session: "Session") -> Optional[object]:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_runtest_protocol( def pytest_runtest_protocol(item: Item, nextitem: Item | None) -> object | None:
item: "Item", nextitem: "Optional[Item]"
) -> Optional[object]:
"""Perform the runtest protocol for a single test item. """Perform the runtest protocol for a single test item.
The default runtest protocol is this (see individual hooks for full details): The default runtest protocol is this (see individual hooks for full details):
@ -682,9 +679,7 @@ def pytest_runtest_protocol(
""" """
def pytest_runtest_logstart( def pytest_runtest_logstart(nodeid: str, location: tuple[str, int | None, str]) -> None:
nodeid: str, location: Tuple[str, Optional[int], str]
) -> None:
"""Called at the start of running the runtest protocol for a single item. """Called at the start of running the runtest protocol for a single item.
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol. See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
@ -703,7 +698,7 @@ def pytest_runtest_logstart(
def pytest_runtest_logfinish( def pytest_runtest_logfinish(
nodeid: str, location: Tuple[str, Optional[int], str] nodeid: str, location: tuple[str, int | None, str]
) -> None: ) -> None:
"""Called at the end of running the runtest protocol for a single item. """Called at the end of running the runtest protocol for a single item.
@ -722,7 +717,7 @@ def pytest_runtest_logfinish(
""" """
def pytest_runtest_setup(item: "Item") -> None: def pytest_runtest_setup(item: Item) -> None:
"""Called to perform the setup phase for a test item. """Called to perform the setup phase for a test item.
The default implementation runs ``setup()`` on ``item`` and all of its The default implementation runs ``setup()`` on ``item`` and all of its
@ -741,7 +736,7 @@ def pytest_runtest_setup(item: "Item") -> None:
""" """
def pytest_runtest_call(item: "Item") -> None: def pytest_runtest_call(item: Item) -> None:
"""Called to run the test for test item (the call phase). """Called to run the test for test item (the call phase).
The default implementation calls ``item.runtest()``. The default implementation calls ``item.runtest()``.
@ -757,7 +752,7 @@ def pytest_runtest_call(item: "Item") -> None:
""" """
def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None: def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
"""Called to perform the teardown phase for a test item. """Called to perform the teardown phase for a test item.
The default implementation runs the finalizers and calls ``teardown()`` The default implementation runs the finalizers and calls ``teardown()``
@ -782,9 +777,7 @@ def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_runtest_makereport( def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport | None:
item: "Item", call: "CallInfo[None]"
) -> Optional["TestReport"]:
"""Called to create a :class:`~pytest.TestReport` for each of """Called to create a :class:`~pytest.TestReport` for each of
the setup, call and teardown runtest phases of a test item. the setup, call and teardown runtest phases of a test item.
@ -803,7 +796,7 @@ def pytest_runtest_makereport(
""" """
def pytest_runtest_logreport(report: "TestReport") -> None: def pytest_runtest_logreport(report: TestReport) -> None:
"""Process the :class:`~pytest.TestReport` produced for each """Process the :class:`~pytest.TestReport` produced for each
of the setup, call and teardown runtest phases of an item. of the setup, call and teardown runtest phases of an item.
@ -819,9 +812,9 @@ def pytest_runtest_logreport(report: "TestReport") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_to_serializable( def pytest_report_to_serializable(
config: "Config", config: Config,
report: Union["CollectReport", "TestReport"], report: CollectReport | TestReport,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
"""Serialize the given report object into a data structure suitable for """Serialize the given report object into a data structure suitable for
sending over the wire, e.g. converted to JSON. sending over the wire, e.g. converted to JSON.
@ -838,9 +831,9 @@ def pytest_report_to_serializable(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_from_serializable( def pytest_report_from_serializable(
config: "Config", config: Config,
data: Dict[str, Any], data: dict[str, Any],
) -> Optional[Union["CollectReport", "TestReport"]]: ) -> CollectReport | TestReport | None:
"""Restore a report object previously serialized with """Restore a report object previously serialized with
:hook:`pytest_report_to_serializable`. :hook:`pytest_report_to_serializable`.
@ -861,8 +854,8 @@ def pytest_report_from_serializable(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_fixture_setup( def pytest_fixture_setup(
fixturedef: "FixtureDef[Any]", request: "SubRequest" fixturedef: FixtureDef[Any], request: SubRequest
) -> Optional[object]: ) -> object | None:
"""Perform fixture setup execution. """Perform fixture setup execution.
:param fixturedef: :param fixturedef:
@ -889,7 +882,7 @@ def pytest_fixture_setup(
def pytest_fixture_post_finalizer( def pytest_fixture_post_finalizer(
fixturedef: "FixtureDef[Any]", request: "SubRequest" fixturedef: FixtureDef[Any], request: SubRequest
) -> None: ) -> None:
"""Called after fixture teardown, but before the cache is cleared, so """Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not the fixture result ``fixturedef.cached_result`` is still available (not
@ -914,7 +907,7 @@ def pytest_fixture_post_finalizer(
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def pytest_sessionstart(session: "Session") -> None: def pytest_sessionstart(session: Session) -> None:
"""Called after the ``Session`` object has been created and before performing collection """Called after the ``Session`` object has been created and before performing collection
and entering the run test loop. and entering the run test loop.
@ -928,8 +921,8 @@ def pytest_sessionstart(session: "Session") -> None:
def pytest_sessionfinish( def pytest_sessionfinish(
session: "Session", session: Session,
exitstatus: Union[int, "ExitCode"], exitstatus: int | ExitCode,
) -> None: ) -> None:
"""Called after whole test run finished, right before returning the exit status to the system. """Called after whole test run finished, right before returning the exit status to the system.
@ -943,7 +936,7 @@ def pytest_sessionfinish(
""" """
def pytest_unconfigure(config: "Config") -> None: def pytest_unconfigure(config: Config) -> None:
"""Called before test process is exited. """Called before test process is exited.
:param config: The pytest config object. :param config: The pytest config object.
@ -961,8 +954,8 @@ def pytest_unconfigure(config: "Config") -> None:
def pytest_assertrepr_compare( def pytest_assertrepr_compare(
config: "Config", op: str, left: object, right: object config: Config, op: str, left: object, right: object
) -> Optional[List[str]]: ) -> list[str] | None:
"""Return explanation for comparisons in failing assert expressions. """Return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list Return None for no custom explanation, otherwise return a list
@ -983,7 +976,7 @@ def pytest_assertrepr_compare(
""" """
def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> None: def pytest_assertion_pass(item: Item, lineno: int, orig: str, expl: str) -> None:
"""Called whenever an assertion passes. """Called whenever an assertion passes.
.. versionadded:: 5.0 .. versionadded:: 5.0
@ -1030,12 +1023,13 @@ def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> No
}, },
) )
def pytest_report_header( # type:ignore[empty-body] def pytest_report_header( # type:ignore[empty-body]
config: "Config", start_path: Path, startdir: "LEGACY_PATH" config: Config, start_path: Path, startdir: LEGACY_PATH
) -> Union[str, List[str]]: ) -> str | list[str]:
"""Return a string or list of strings to be displayed as header info for terminal reporting. """Return a string or list of strings to be displayed as header info for terminal reporting.
:param config: The pytest config object. :param config: The pytest config object.
:param start_path: The starting dir. :param start_path: The starting dir.
:type start_path: pathlib.Path
:param startdir: The starting dir (deprecated). :param startdir: The starting dir (deprecated).
.. note:: .. note::
@ -1065,11 +1059,11 @@ def pytest_report_header( # type:ignore[empty-body]
}, },
) )
def pytest_report_collectionfinish( # type:ignore[empty-body] def pytest_report_collectionfinish( # type:ignore[empty-body]
config: "Config", config: Config,
start_path: Path, start_path: Path,
startdir: "LEGACY_PATH", startdir: LEGACY_PATH,
items: Sequence["Item"], items: Sequence[Item],
) -> Union[str, List[str]]: ) -> str | list[str]:
"""Return a string or list of strings to be displayed after collection """Return a string or list of strings to be displayed after collection
has finished successfully. has finished successfully.
@ -1079,6 +1073,7 @@ def pytest_report_collectionfinish( # type:ignore[empty-body]
:param config: The pytest config object. :param config: The pytest config object.
:param start_path: The starting dir. :param start_path: The starting dir.
:type start_path: pathlib.Path
:param startdir: The starting dir (deprecated). :param startdir: The starting dir (deprecated).
:param items: List of pytest items that are going to be executed; this list should not be modified. :param items: List of pytest items that are going to be executed; this list should not be modified.
@ -1103,8 +1098,8 @@ def pytest_report_collectionfinish( # type:ignore[empty-body]
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_teststatus( # type:ignore[empty-body] def pytest_report_teststatus( # type:ignore[empty-body]
report: Union["CollectReport", "TestReport"], config: "Config" report: CollectReport | TestReport, config: Config
) -> "TestShortLogReport | Tuple[str, str, Union[str, Tuple[str, Mapping[str, bool]]]]": ) -> TestShortLogReport | tuple[str, str, str | tuple[str, Mapping[str, bool]]]:
"""Return result-category, shortletter and verbose word for status """Return result-category, shortletter and verbose word for status
reporting. reporting.
@ -1135,9 +1130,9 @@ def pytest_report_teststatus( # type:ignore[empty-body]
def pytest_terminal_summary( def pytest_terminal_summary(
terminalreporter: "TerminalReporter", terminalreporter: TerminalReporter,
exitstatus: "ExitCode", exitstatus: ExitCode,
config: "Config", config: Config,
) -> None: ) -> None:
"""Add a section to terminal summary reporting. """Add a section to terminal summary reporting.
@ -1157,10 +1152,10 @@ def pytest_terminal_summary(
@hookspec(historic=True) @hookspec(historic=True)
def pytest_warning_recorded( def pytest_warning_recorded(
warning_message: "warnings.WarningMessage", warning_message: warnings.WarningMessage,
when: "Literal['config', 'collect', 'runtest']", when: Literal["config", "collect", "runtest"],
nodeid: str, nodeid: str,
location: Optional[Tuple[str, int, str]], location: tuple[str, int, str] | None,
) -> None: ) -> None:
"""Process a warning captured by the internal pytest warnings plugin. """Process a warning captured by the internal pytest warnings plugin.
@ -1201,8 +1196,8 @@ def pytest_warning_recorded(
def pytest_markeval_namespace( # type:ignore[empty-body] def pytest_markeval_namespace( # type:ignore[empty-body]
config: "Config", config: Config,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Called when constructing the globals dictionary used for """Called when constructing the globals dictionary used for
evaluating string conditions in xfail/skipif markers. evaluating string conditions in xfail/skipif markers.
@ -1230,9 +1225,9 @@ def pytest_markeval_namespace( # type:ignore[empty-body]
def pytest_internalerror( def pytest_internalerror(
excrepr: "ExceptionRepr", excrepr: ExceptionRepr,
excinfo: "ExceptionInfo[BaseException]", excinfo: ExceptionInfo[BaseException],
) -> Optional[bool]: ) -> bool | None:
"""Called for internal errors. """Called for internal errors.
Return True to suppress the fallback handling of printing an Return True to suppress the fallback handling of printing an
@ -1249,7 +1244,7 @@ def pytest_internalerror(
def pytest_keyboard_interrupt( def pytest_keyboard_interrupt(
excinfo: "ExceptionInfo[Union[KeyboardInterrupt, Exit]]", excinfo: ExceptionInfo[KeyboardInterrupt | Exit],
) -> None: ) -> None:
"""Called for keyboard interrupt. """Called for keyboard interrupt.
@ -1263,9 +1258,9 @@ def pytest_keyboard_interrupt(
def pytest_exception_interact( def pytest_exception_interact(
node: Union["Item", "Collector"], node: Item | Collector,
call: "CallInfo[Any]", call: CallInfo[Any],
report: Union["CollectReport", "TestReport"], report: CollectReport | TestReport,
) -> None: ) -> None:
"""Called when an exception was raised which can potentially be """Called when an exception was raised which can potentially be
interactively handled. interactively handled.
@ -1294,7 +1289,7 @@ def pytest_exception_interact(
""" """
def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None: def pytest_enter_pdb(config: Config, pdb: pdb.Pdb) -> None:
"""Called upon pdb.set_trace(). """Called upon pdb.set_trace().
Can be used by plugins to take special action just before the python Can be used by plugins to take special action just before the python
@ -1310,7 +1305,7 @@ def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
""" """
def pytest_leave_pdb(config: "Config", pdb: "pdb.Pdb") -> None: def pytest_leave_pdb(config: Config, pdb: pdb.Pdb) -> None:
"""Called when leaving pdb (e.g. with continue after pdb.set_trace()). """Called when leaving pdb (e.g. with continue after pdb.set_trace()).
Can be used by plugins to take special action just after the python Can be used by plugins to take special action just after the python

View File

@ -8,18 +8,15 @@ Output conforms to
https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
""" """
from __future__ import annotations
from datetime import datetime from datetime import datetime
import functools import functools
import os import os
import platform import platform
import re import re
from typing import Callable from typing import Callable
from typing import Dict
from typing import List
from typing import Match from typing import Match
from typing import Optional
from typing import Tuple
from typing import Union
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from _pytest import nodes from _pytest import nodes
@ -89,15 +86,15 @@ families["xunit2"] = families["_base"]
class _NodeReporter: class _NodeReporter:
def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None: def __init__(self, nodeid: str | TestReport, xml: LogXML) -> None:
self.id = nodeid self.id = nodeid
self.xml = xml self.xml = xml
self.add_stats = self.xml.add_stats self.add_stats = self.xml.add_stats
self.family = self.xml.family self.family = self.xml.family
self.duration = 0.0 self.duration = 0.0
self.properties: List[Tuple[str, str]] = [] self.properties: list[tuple[str, str]] = []
self.nodes: List[ET.Element] = [] self.nodes: list[ET.Element] = []
self.attrs: Dict[str, str] = {} self.attrs: dict[str, str] = {}
def append(self, node: ET.Element) -> None: def append(self, node: ET.Element) -> None:
self.xml.add_stats(node.tag) self.xml.add_stats(node.tag)
@ -109,7 +106,7 @@ class _NodeReporter:
def add_attribute(self, name: str, value: object) -> None: def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value) self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self) -> Optional[ET.Element]: def make_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any.""" """Return a Junit node containing custom properties, if any."""
if self.properties: if self.properties:
properties = ET.Element("properties") properties = ET.Element("properties")
@ -124,7 +121,7 @@ class _NodeReporter:
classnames = names[:-1] classnames = names[:-1]
if self.xml.prefix: if self.xml.prefix:
classnames.insert(0, self.xml.prefix) classnames.insert(0, self.xml.prefix)
attrs: Dict[str, str] = { attrs: dict[str, str] = {
"classname": ".".join(classnames), "classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]), "name": bin_xml_escape(names[-1]),
"file": testreport.location[0], "file": testreport.location[0],
@ -156,7 +153,7 @@ class _NodeReporter:
testcase.extend(self.nodes) testcase.extend(self.nodes)
return testcase return testcase
def _add_simple(self, tag: str, message: str, data: Optional[str] = None) -> None: def _add_simple(self, tag: str, message: str, data: str | None = None) -> None:
node = ET.Element(tag, message=message) node = ET.Element(tag, message=message)
node.text = bin_xml_escape(data) node.text = bin_xml_escape(data)
self.append(node) self.append(node)
@ -201,7 +198,7 @@ class _NodeReporter:
self._add_simple("skipped", "xfail-marked test passes unexpectedly") self._add_simple("skipped", "xfail-marked test passes unexpectedly")
else: else:
assert report.longrepr is not None assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr( reprcrash: ReprFileLocation | None = getattr(
report.longrepr, "reprcrash", None report.longrepr, "reprcrash", None
) )
if reprcrash is not None: if reprcrash is not None:
@ -221,9 +218,7 @@ class _NodeReporter:
def append_error(self, report: TestReport) -> None: def append_error(self, report: TestReport) -> None:
assert report.longrepr is not None assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr( reprcrash: ReprFileLocation | None = getattr(report.longrepr, "reprcrash", None)
report.longrepr, "reprcrash", None
)
if reprcrash is not None: if reprcrash is not None:
reason = reprcrash.message reason = reprcrash.message
else: else:
@ -451,7 +446,7 @@ def pytest_unconfigure(config: Config) -> None:
config.pluginmanager.unregister(xml) config.pluginmanager.unregister(xml)
def mangle_test_address(address: str) -> List[str]: def mangle_test_address(address: str) -> list[str]:
path, possible_open_bracket, params = address.partition("[") path, possible_open_bracket, params = address.partition("[")
names = path.split("::") names = path.split("::")
# Convert file path to dotted path. # Convert file path to dotted path.
@ -466,7 +461,7 @@ class LogXML:
def __init__( def __init__(
self, self,
logfile, logfile,
prefix: Optional[str], prefix: str | None,
suite_name: str = "pytest", suite_name: str = "pytest",
logging: str = "no", logging: str = "no",
report_duration: str = "total", report_duration: str = "total",
@ -481,17 +476,15 @@ class LogXML:
self.log_passing_tests = log_passing_tests self.log_passing_tests = log_passing_tests
self.report_duration = report_duration self.report_duration = report_duration
self.family = family self.family = family
self.stats: Dict[str, int] = dict.fromkeys( self.stats: dict[str, int] = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0 ["error", "passed", "failure", "skipped"], 0
) )
self.node_reporters: Dict[ self.node_reporters: dict[tuple[str | TestReport, object], _NodeReporter] = {}
Tuple[Union[str, TestReport], object], _NodeReporter self.node_reporters_ordered: list[_NodeReporter] = []
] = {} self.global_properties: list[tuple[str, str]] = []
self.node_reporters_ordered: List[_NodeReporter] = []
self.global_properties: List[Tuple[str, str]] = []
# List of reports that failed on call but teardown is pending. # List of reports that failed on call but teardown is pending.
self.open_reports: List[TestReport] = [] self.open_reports: list[TestReport] = []
self.cnt_double_fail_tests = 0 self.cnt_double_fail_tests = 0
# Replaces convenience family with real family. # Replaces convenience family with real family.
@ -510,8 +503,8 @@ class LogXML:
if reporter is not None: if reporter is not None:
reporter.finalize() reporter.finalize()
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter: def node_reporter(self, report: TestReport | str) -> _NodeReporter:
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report) nodeid: str | TestReport = getattr(report, "nodeid", report)
# Local hack to handle xdist report order. # Local hack to handle xdist report order.
workernode = getattr(report, "node", None) workernode = getattr(report, "node", None)
@ -691,7 +684,7 @@ class LogXML:
_check_record_param_type("name", name) _check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value))) self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self) -> Optional[ET.Element]: def _get_global_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any.""" """Return a Junit node containing custom properties, if any."""
if self.global_properties: if self.global_properties:
properties = ET.Element("properties") properties = ET.Element("properties")

View File

@ -1,16 +1,15 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Add backward compatibility support for the legacy py path type.""" """Add backward compatibility support for the legacy py path type."""
from __future__ import annotations
import dataclasses import dataclasses
from pathlib import Path from pathlib import Path
import shlex import shlex
import subprocess import subprocess
from typing import Final from typing import Final
from typing import final from typing import final
from typing import List
from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
from iniconfig import SectionWrapper from iniconfig import SectionWrapper
@ -50,8 +49,8 @@ class Testdir:
__test__ = False __test__ = False
CLOSE_STDIN: "Final" = Pytester.CLOSE_STDIN CLOSE_STDIN: Final = Pytester.CLOSE_STDIN
TimeoutExpired: "Final" = Pytester.TimeoutExpired TimeoutExpired: Final = Pytester.TimeoutExpired
def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None: def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
@ -145,7 +144,7 @@ class Testdir:
"""See :meth:`Pytester.copy_example`.""" """See :meth:`Pytester.copy_example`."""
return legacy_path(self._pytester.copy_example(name)) return legacy_path(self._pytester.copy_example(name))
def getnode(self, config: Config, arg) -> Optional[Union[Item, Collector]]: def getnode(self, config: Config, arg) -> Item | Collector | None:
"""See :meth:`Pytester.getnode`.""" """See :meth:`Pytester.getnode`."""
return self._pytester.getnode(config, arg) return self._pytester.getnode(config, arg)
@ -153,7 +152,7 @@ class Testdir:
"""See :meth:`Pytester.getpathnode`.""" """See :meth:`Pytester.getpathnode`."""
return self._pytester.getpathnode(path) return self._pytester.getpathnode(path)
def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]: def genitems(self, colitems: list[Item | Collector]) -> list[Item]:
"""See :meth:`Pytester.genitems`.""" """See :meth:`Pytester.genitems`."""
return self._pytester.genitems(colitems) return self._pytester.genitems(colitems)
@ -205,9 +204,7 @@ class Testdir:
source, configargs=configargs, withinit=withinit source, configargs=configargs, withinit=withinit
) )
def collect_by_name( def collect_by_name(self, modcol: Collector, name: str) -> Item | Collector | None:
self, modcol: Collector, name: str
) -> Optional[Union[Item, Collector]]:
"""See :meth:`Pytester.collect_by_name`.""" """See :meth:`Pytester.collect_by_name`."""
return self._pytester.collect_by_name(modcol, name) return self._pytester.collect_by_name(modcol, name)
@ -238,13 +235,11 @@ class Testdir:
"""See :meth:`Pytester.runpytest_subprocess`.""" """See :meth:`Pytester.runpytest_subprocess`."""
return self._pytester.runpytest_subprocess(*args, timeout=timeout) return self._pytester.runpytest_subprocess(*args, timeout=timeout)
def spawn_pytest( def spawn_pytest(self, string: str, expect_timeout: float = 10.0) -> pexpect.spawn:
self, string: str, expect_timeout: float = 10.0
) -> "pexpect.spawn":
"""See :meth:`Pytester.spawn_pytest`.""" """See :meth:`Pytester.spawn_pytest`."""
return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout) return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn": def spawn(self, cmd: str, expect_timeout: float = 10.0) -> pexpect.spawn:
"""See :meth:`Pytester.spawn`.""" """See :meth:`Pytester.spawn`."""
return self._pytester.spawn(cmd, expect_timeout=expect_timeout) return self._pytester.spawn(cmd, expect_timeout=expect_timeout)
@ -374,7 +369,7 @@ def Config_rootdir(self: Config) -> LEGACY_PATH:
return legacy_path(str(self.rootpath)) return legacy_path(str(self.rootpath))
def Config_inifile(self: Config) -> Optional[LEGACY_PATH]: def Config_inifile(self: Config) -> LEGACY_PATH | None:
"""The path to the :ref:`configfile <configfiles>`. """The path to the :ref:`configfile <configfiles>`.
Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`. Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`.
@ -394,9 +389,7 @@ def Session_startdir(self: Session) -> LEGACY_PATH:
return legacy_path(self.startpath) return legacy_path(self.startpath)
def Config__getini_unknown_type( def Config__getini_unknown_type(self, name: str, type: str, value: str | list[str]):
self, name: str, type: str, value: Union[str, List[str]]
):
if type == "pathlist": if type == "pathlist":
# TODO: This assert is probably not valid in all cases. # TODO: This assert is probably not valid in all cases.
assert self.inipath is not None assert self.inipath is not None

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Access and control log capturing.""" """Access and control log capturing."""
from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
@ -22,12 +24,8 @@ from typing import Generic
from typing import List from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union
from _pytest import nodes from _pytest import nodes
from _pytest._io import TerminalWriter from _pytest._io import TerminalWriter
@ -68,7 +66,7 @@ class DatetimeFormatter(logging.Formatter):
:func:`time.strftime` in case of microseconds in format string. :func:`time.strftime` in case of microseconds in format string.
""" """
def formatTime(self, record: LogRecord, datefmt: Optional[str] = None) -> str: def formatTime(self, record: LogRecord, datefmt: str | None = None) -> str:
if datefmt and "%f" in datefmt: if datefmt and "%f" in datefmt:
ct = self.converter(record.created) ct = self.converter(record.created)
tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone) tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone)
@ -100,7 +98,7 @@ class ColoredLevelFormatter(DatetimeFormatter):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._terminalwriter = terminalwriter self._terminalwriter = terminalwriter
self._original_fmt = self._style._fmt self._original_fmt = self._style._fmt
self._level_to_fmt_mapping: Dict[int, str] = {} self._level_to_fmt_mapping: dict[int, str] = {}
for level, color_opts in self.LOGLEVEL_COLOROPTS.items(): for level, color_opts in self.LOGLEVEL_COLOROPTS.items():
self.add_color_level(level, *color_opts) self.add_color_level(level, *color_opts)
@ -148,12 +146,12 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately. formats the message as if each line were logged separately.
""" """
def __init__(self, fmt: str, auto_indent: Union[int, str, bool, None]) -> None: def __init__(self, fmt: str, auto_indent: int | str | bool | None) -> None:
super().__init__(fmt) super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent) self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod @staticmethod
def _get_auto_indent(auto_indent_option: Union[int, str, bool, None]) -> int: def _get_auto_indent(auto_indent_option: int | str | bool | None) -> int:
"""Determine the current auto indentation setting. """Determine the current auto indentation setting.
Specify auto indent behavior (on/off/fixed) by passing in Specify auto indent behavior (on/off/fixed) by passing in
@ -348,7 +346,7 @@ class catching_logs(Generic[_HandlerType]):
__slots__ = ("handler", "level", "orig_level") __slots__ = ("handler", "level", "orig_level")
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None: def __init__(self, handler: _HandlerType, level: int | None = None) -> None:
self.handler = handler self.handler = handler
self.level = level self.level = level
@ -364,9 +362,9 @@ class catching_logs(Generic[_HandlerType]):
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
root_logger = logging.getLogger() root_logger = logging.getLogger()
if self.level is not None: if self.level is not None:
@ -380,7 +378,7 @@ class LogCaptureHandler(logging_StreamHandler):
def __init__(self) -> None: def __init__(self) -> None:
"""Create a new log handler.""" """Create a new log handler."""
super().__init__(StringIO()) super().__init__(StringIO())
self.records: List[logging.LogRecord] = [] self.records: list[logging.LogRecord] = []
def emit(self, record: logging.LogRecord) -> None: def emit(self, record: logging.LogRecord) -> None:
"""Keep the log records in a list in addition to the log text.""" """Keep the log records in a list in addition to the log text."""
@ -411,10 +409,10 @@ class LogCaptureFixture:
def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None: def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._item = item self._item = item
self._initial_handler_level: Optional[int] = None self._initial_handler_level: int | None = None
# Dict of log name -> log level. # Dict of log name -> log level.
self._initial_logger_levels: Dict[Optional[str], int] = {} self._initial_logger_levels: dict[str | None, int] = {}
self._initial_disabled_logging_level: Optional[int] = None self._initial_disabled_logging_level: int | None = None
def _finalize(self) -> None: def _finalize(self) -> None:
"""Finalize the fixture. """Finalize the fixture.
@ -439,7 +437,7 @@ class LogCaptureFixture:
def get_records( def get_records(
self, when: Literal["setup", "call", "teardown"] self, when: Literal["setup", "call", "teardown"]
) -> List[logging.LogRecord]: ) -> list[logging.LogRecord]:
"""Get the logging records for one of the possible test phases. """Get the logging records for one of the possible test phases.
:param when: :param when:
@ -458,12 +456,12 @@ class LogCaptureFixture:
return _remove_ansi_escape_sequences(self.handler.stream.getvalue()) return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property @property
def records(self) -> List[logging.LogRecord]: def records(self) -> list[logging.LogRecord]:
"""The list of log records.""" """The list of log records."""
return self.handler.records return self.handler.records
@property @property
def record_tuples(self) -> List[Tuple[str, int, str]]: def record_tuples(self) -> list[tuple[str, int, str]]:
"""A list of a stripped down version of log records intended """A list of a stripped down version of log records intended
for use in assertion comparison. for use in assertion comparison.
@ -474,7 +472,7 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records] return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property @property
def messages(self) -> List[str]: def messages(self) -> list[str]:
"""A list of format-interpolated log messages. """A list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for Unlike 'records', which contains the format string and parameters for
@ -497,7 +495,7 @@ class LogCaptureFixture:
self.handler.clear() self.handler.clear()
def _force_enable_logging( def _force_enable_logging(
self, level: Union[int, str], logger_obj: logging.Logger self, level: int | str, logger_obj: logging.Logger
) -> int: ) -> int:
"""Enable the desired logging level if the global level was disabled via ``logging.disabled``. """Enable the desired logging level if the global level was disabled via ``logging.disabled``.
@ -530,7 +528,7 @@ class LogCaptureFixture:
return original_disable_level return original_disable_level
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None: def set_level(self, level: int | str, logger: str | None = None) -> None:
"""Set the threshold level of a logger for the duration of a test. """Set the threshold level of a logger for the duration of a test.
Logging messages which are less severe than this level will not be captured. Logging messages which are less severe than this level will not be captured.
@ -557,7 +555,7 @@ class LogCaptureFixture:
@contextmanager @contextmanager
def at_level( def at_level(
self, level: Union[int, str], logger: Optional[str] = None self, level: int | str, logger: str | None = None
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After """Context manager that sets the level for capturing of logs. After
the end of the 'with' statement the level is restored to its original the end of the 'with' statement the level is restored to its original
@ -615,7 +613,7 @@ def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
result._finalize() result._finalize()
def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[int]: def get_log_level_for_setting(config: Config, *setting_names: str) -> int | None:
for setting_name in setting_names: for setting_name in setting_names:
log_level = config.getoption(setting_name) log_level = config.getoption(setting_name)
if log_level is None: if log_level is None:
@ -701,9 +699,9 @@ class LoggingPlugin:
assert terminal_reporter is not None assert terminal_reporter is not None
capture_manager = config.pluginmanager.get_plugin("capturemanager") capture_manager = config.pluginmanager.get_plugin("capturemanager")
# if capturemanager plugin is disabled, live logging still works. # if capturemanager plugin is disabled, live logging still works.
self.log_cli_handler: Union[ self.log_cli_handler: (
_LiveLoggingStreamHandler, _LiveLoggingNullHandler _LiveLoggingStreamHandler | _LiveLoggingNullHandler
] = _LiveLoggingStreamHandler(terminal_reporter, capture_manager) ) = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
else: else:
self.log_cli_handler = _LiveLoggingNullHandler() self.log_cli_handler = _LiveLoggingNullHandler()
log_cli_formatter = self._create_formatter( log_cli_formatter = self._create_formatter(
@ -714,7 +712,7 @@ class LoggingPlugin:
self.log_cli_handler.setFormatter(log_cli_formatter) self.log_cli_handler.setFormatter(log_cli_formatter)
self._disable_loggers(loggers_to_disable=config.option.logger_disable) self._disable_loggers(loggers_to_disable=config.option.logger_disable)
def _disable_loggers(self, loggers_to_disable: List[str]) -> None: def _disable_loggers(self, loggers_to_disable: list[str]) -> None:
if not loggers_to_disable: if not loggers_to_disable:
return return
@ -839,7 +837,7 @@ class LoggingPlugin:
def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]: def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("setup") self.log_cli_handler.set_when("setup")
empty: Dict[str, List[logging.LogRecord]] = {} empty: dict[str, list[logging.LogRecord]] = {}
item.stash[caplog_records_key] = empty item.stash[caplog_records_key] = empty
yield from self._runtest_for(item, "setup") yield from self._runtest_for(item, "setup")
@ -902,7 +900,7 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
def __init__( def __init__(
self, self,
terminal_reporter: TerminalReporter, terminal_reporter: TerminalReporter,
capture_manager: Optional[CaptureManager], capture_manager: CaptureManager | None,
) -> None: ) -> None:
super().__init__(stream=terminal_reporter) # type: ignore[arg-type] super().__init__(stream=terminal_reporter) # type: ignore[arg-type]
self.capture_manager = capture_manager self.capture_manager = capture_manager
@ -914,7 +912,7 @@ class _LiveLoggingStreamHandler(logging_StreamHandler):
"""Reset the handler; should be called before the start of each test.""" """Reset the handler; should be called before the start of each test."""
self._first_record_emitted = False self._first_record_emitted = False
def set_when(self, when: Optional[str]) -> None: def set_when(self, when: str | None) -> None:
"""Prepare for the given test phase (setup/call/teardown).""" """Prepare for the given test phase (setup/call/teardown)."""
self._when = when self._when = when
self._section_name_shown = False self._section_name_shown = False

View File

@ -1,5 +1,7 @@
"""Core implementation of the testing process: init, session, runtest loop.""" """Core implementation of the testing process: init, session, runtest loop."""
from __future__ import annotations
import argparse import argparse
import dataclasses import dataclasses
import fnmatch import fnmatch
@ -13,17 +15,12 @@ from typing import AbstractSet
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import final from typing import final
from typing import FrozenSet
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Literal from typing import Literal
from typing import Optional
from typing import overload from typing import overload
from typing import Sequence from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import warnings import warnings
import pluggy import pluggy
@ -270,8 +267,8 @@ def validate_basetemp(path: str) -> str:
def wrap_session( def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]] config: Config, doit: Callable[[Config, Session], int | ExitCode | None]
) -> Union[int, ExitCode]: ) -> int | ExitCode:
"""Skeleton command line program.""" """Skeleton command line program."""
session = Session.from_config(config) session = Session.from_config(config)
session.exitstatus = ExitCode.OK session.exitstatus = ExitCode.OK
@ -290,7 +287,7 @@ def wrap_session(
session.exitstatus = ExitCode.TESTS_FAILED session.exitstatus = ExitCode.TESTS_FAILED
except (KeyboardInterrupt, exit.Exception): except (KeyboardInterrupt, exit.Exception):
excinfo = _pytest._code.ExceptionInfo.from_current() excinfo = _pytest._code.ExceptionInfo.from_current()
exitstatus: Union[int, ExitCode] = ExitCode.INTERRUPTED exitstatus: int | ExitCode = ExitCode.INTERRUPTED
if isinstance(excinfo.value, exit.Exception): if isinstance(excinfo.value, exit.Exception):
if excinfo.value.returncode is not None: if excinfo.value.returncode is not None:
exitstatus = excinfo.value.returncode exitstatus = excinfo.value.returncode
@ -328,11 +325,11 @@ def wrap_session(
return session.exitstatus return session.exitstatus
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]: def pytest_cmdline_main(config: Config) -> int | ExitCode:
return wrap_session(config, _main) return wrap_session(config, _main)
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]: def _main(config: Config, session: Session) -> int | ExitCode | None:
"""Default command line protocol for initialization, session, """Default command line protocol for initialization, session,
running tests and reporting.""" running tests and reporting."""
config.hook.pytest_collection(session=session) config.hook.pytest_collection(session=session)
@ -345,11 +342,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None return None
def pytest_collection(session: "Session") -> None: def pytest_collection(session: Session) -> None:
session.perform_collect() session.perform_collect()
def pytest_runtestloop(session: "Session") -> bool: def pytest_runtestloop(session: Session) -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors: if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted( raise session.Interrupted(
"%d error%s during collection" "%d error%s during collection"
@ -389,7 +386,7 @@ def _in_venv(path: Path) -> bool:
return any(fname.name in activates for fname in bindir.iterdir()) return any(fname.name in activates for fname in bindir.iterdir())
def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[bool]: def pytest_ignore_collect(collection_path: Path, config: Config) -> bool | None:
if collection_path.name == "__pycache__": if collection_path.name == "__pycache__":
return True return True
@ -429,11 +426,11 @@ def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[boo
def pytest_collect_directory( def pytest_collect_directory(
path: Path, parent: nodes.Collector path: Path, parent: nodes.Collector
) -> Optional[nodes.Collector]: ) -> nodes.Collector | None:
return Dir.from_parent(parent, path=path) return Dir.from_parent(parent, path=path)
def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None: def pytest_collection_modifyitems(items: list[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or []) deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes: if not deselect_prefixes:
return return
@ -507,17 +504,18 @@ class Dir(nodes.Directory):
parent: nodes.Collector, parent: nodes.Collector,
*, *,
path: Path, path: Path,
) -> "Self": ) -> Self:
"""The public constructor. """The public constructor.
:param parent: The parent collector of this Dir. :param parent: The parent collector of this Dir.
:param path: The directory's path. :param path: The directory's path.
:type path: pathlib.Path
""" """
return super().from_parent(parent=parent, path=path) return super().from_parent(parent=parent, path=path)
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
config = self.config config = self.config
col: Optional[nodes.Collector] col: nodes.Collector | None
cols: Sequence[nodes.Collector] cols: Sequence[nodes.Collector]
ihook = self.ihook ihook = self.ihook
for direntry in scandir(self.path): for direntry in scandir(self.path):
@ -552,7 +550,7 @@ class Session(nodes.Collector):
_setupstate: SetupState _setupstate: SetupState
# Set on the session by fixtures.pytest_sessionstart. # Set on the session by fixtures.pytest_sessionstart.
_fixturemanager: FixtureManager _fixturemanager: FixtureManager
exitstatus: Union[int, ExitCode] exitstatus: int | ExitCode
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
super().__init__( super().__init__(
@ -566,22 +564,22 @@ class Session(nodes.Collector):
) )
self.testsfailed = 0 self.testsfailed = 0
self.testscollected = 0 self.testscollected = 0
self._shouldstop: Union[bool, str] = False self._shouldstop: bool | str = False
self._shouldfail: Union[bool, str] = False self._shouldfail: bool | str = False
self.trace = config.trace.root.get("collection") self.trace = config.trace.root.get("collection")
self._initialpaths: FrozenSet[Path] = frozenset() self._initialpaths: frozenset[Path] = frozenset()
self._initialpaths_with_parents: FrozenSet[Path] = frozenset() self._initialpaths_with_parents: frozenset[Path] = frozenset()
self._notfound: List[Tuple[str, Sequence[nodes.Collector]]] = [] self._notfound: list[tuple[str, Sequence[nodes.Collector]]] = []
self._initial_parts: List[CollectionArgument] = [] self._initial_parts: list[CollectionArgument] = []
self._collection_cache: Dict[nodes.Collector, CollectReport] = {} self._collection_cache: dict[nodes.Collector, CollectReport] = {}
self.items: List[nodes.Item] = [] self.items: list[nodes.Item] = []
self._bestrelpathcache: Dict[Path, str] = _bestrelpath_cache(config.rootpath) self._bestrelpathcache: dict[Path, str] = _bestrelpath_cache(config.rootpath)
self.config.pluginmanager.register(self, name="session") self.config.pluginmanager.register(self, name="session")
@classmethod @classmethod
def from_config(cls, config: Config) -> "Session": def from_config(cls, config: Config) -> Session:
session: Session = cls._create(config=config) session: Session = cls._create(config=config)
return session return session
@ -595,11 +593,11 @@ class Session(nodes.Collector):
) )
@property @property
def shouldstop(self) -> Union[bool, str]: def shouldstop(self) -> bool | str:
return self._shouldstop return self._shouldstop
@shouldstop.setter @shouldstop.setter
def shouldstop(self, value: Union[bool, str]) -> None: def shouldstop(self, value: bool | str) -> None:
# The runner checks shouldfail and assumes that if it is set we are # The runner checks shouldfail and assumes that if it is set we are
# definitely stopping, so prevent unsetting it. # definitely stopping, so prevent unsetting it.
if value is False and self._shouldstop: if value is False and self._shouldstop:
@ -613,11 +611,11 @@ class Session(nodes.Collector):
self._shouldstop = value self._shouldstop = value
@property @property
def shouldfail(self) -> Union[bool, str]: def shouldfail(self) -> bool | str:
return self._shouldfail return self._shouldfail
@shouldfail.setter @shouldfail.setter
def shouldfail(self, value: Union[bool, str]) -> None: def shouldfail(self, value: bool | str) -> None:
# The runner checks shouldfail and assumes that if it is set we are # The runner checks shouldfail and assumes that if it is set we are
# definitely stopping, so prevent unsetting it. # definitely stopping, so prevent unsetting it.
if value is False and self._shouldfail: if value is False and self._shouldfail:
@ -650,9 +648,7 @@ class Session(nodes.Collector):
raise self.Interrupted(self.shouldstop) raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_runtest_logreport( def pytest_runtest_logreport(self, report: TestReport | CollectReport) -> None:
self, report: Union[TestReport, CollectReport]
) -> None:
if report.failed and not hasattr(report, "wasxfail"): if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1 self.testsfailed += 1
maxfail = self.config.getvalue("maxfail") maxfail = self.config.getvalue("maxfail")
@ -663,7 +659,7 @@ class Session(nodes.Collector):
def isinitpath( def isinitpath(
self, self,
path: Union[str, "os.PathLike[str]"], path: str | os.PathLike[str],
*, *,
with_parents: bool = False, with_parents: bool = False,
) -> bool: ) -> bool:
@ -685,7 +681,7 @@ class Session(nodes.Collector):
else: else:
return path_ in self._initialpaths return path_ in self._initialpaths
def gethookproxy(self, fspath: "os.PathLike[str]") -> pluggy.HookRelay: def gethookproxy(self, fspath: os.PathLike[str]) -> pluggy.HookRelay:
# Optimization: Path(Path(...)) is much slower than isinstance. # Optimization: Path(Path(...)) is much slower than isinstance.
path = fspath if isinstance(fspath, Path) else Path(fspath) path = fspath if isinstance(fspath, Path) else Path(fspath)
pm = self.config.pluginmanager pm = self.config.pluginmanager
@ -705,7 +701,7 @@ class Session(nodes.Collector):
def _collect_path( def _collect_path(
self, self,
path: Path, path: Path,
path_cache: Dict[Path, Sequence[nodes.Collector]], path_cache: dict[Path, Sequence[nodes.Collector]],
) -> Sequence[nodes.Collector]: ) -> Sequence[nodes.Collector]:
"""Create a Collector for the given path. """Create a Collector for the given path.
@ -717,7 +713,7 @@ class Session(nodes.Collector):
if path.is_dir(): if path.is_dir():
ihook = self.gethookproxy(path.parent) ihook = self.gethookproxy(path.parent)
col: Optional[nodes.Collector] = ihook.pytest_collect_directory( col: nodes.Collector | None = ihook.pytest_collect_directory(
path=path, parent=self path=path, parent=self
) )
cols: Sequence[nodes.Collector] = (col,) if col is not None else () cols: Sequence[nodes.Collector] = (col,) if col is not None else ()
@ -735,17 +731,17 @@ class Session(nodes.Collector):
@overload @overload
def perform_collect( def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ... self, args: Sequence[str] | None = ..., genitems: Literal[True] = ...
) -> Sequence[nodes.Item]: ... ) -> Sequence[nodes.Item]: ...
@overload @overload
def perform_collect( def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: bool = ... self, args: Sequence[str] | None = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]: ... ) -> Sequence[nodes.Item | nodes.Collector]: ...
def perform_collect( def perform_collect(
self, args: Optional[Sequence[str]] = None, genitems: bool = True self, args: Sequence[str] | None = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]: ) -> Sequence[nodes.Item | nodes.Collector]:
"""Perform the collection phase for this session. """Perform the collection phase for this session.
This is called by the default :hook:`pytest_collection` hook This is called by the default :hook:`pytest_collection` hook
@ -771,10 +767,10 @@ class Session(nodes.Collector):
self._initial_parts = [] self._initial_parts = []
self._collection_cache = {} self._collection_cache = {}
self.items = [] self.items = []
items: Sequence[Union[nodes.Item, nodes.Collector]] = self.items items: Sequence[nodes.Item | nodes.Collector] = self.items
try: try:
initialpaths: List[Path] = [] initialpaths: list[Path] = []
initialpaths_with_parents: List[Path] = [] initialpaths_with_parents: list[Path] = []
for arg in args: for arg in args:
collection_argument = resolve_collection_argument( collection_argument = resolve_collection_argument(
self.config.invocation_params.dir, self.config.invocation_params.dir,
@ -829,7 +825,7 @@ class Session(nodes.Collector):
self, self,
node: nodes.Collector, node: nodes.Collector,
handle_dupes: bool = True, handle_dupes: bool = True,
) -> Tuple[CollectReport, bool]: ) -> tuple[CollectReport, bool]:
if node in self._collection_cache and handle_dupes: if node in self._collection_cache and handle_dupes:
rep = self._collection_cache[node] rep = self._collection_cache[node]
return rep, True return rep, True
@ -838,11 +834,11 @@ class Session(nodes.Collector):
self._collection_cache[node] = rep self._collection_cache[node] = rep
return rep, False return rep, False
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterator[nodes.Item | nodes.Collector]:
# This is a cache for the root directories of the initial paths. # This is a cache for the root directories of the initial paths.
# We can't use collection_cache for Session because of its special # We can't use collection_cache for Session because of its special
# role as the bootstrapping collector. # role as the bootstrapping collector.
path_cache: Dict[Path, Sequence[nodes.Collector]] = {} path_cache: dict[Path, Sequence[nodes.Collector]] = {}
pm = self.config.pluginmanager pm = self.config.pluginmanager
@ -880,9 +876,9 @@ class Session(nodes.Collector):
# and discarding all nodes which don't match the level's part. # and discarding all nodes which don't match the level's part.
any_matched_in_initial_part = False any_matched_in_initial_part = False
notfound_collectors = [] notfound_collectors = []
work: List[ work: list[tuple[nodes.Collector | nodes.Item, list[Path | str]]] = [
Tuple[Union[nodes.Collector, nodes.Item], List[Union[Path, str]]] (self, [*paths, *names])
] = [(self, [*paths, *names])] ]
while work: while work:
matchnode, matchparts = work.pop() matchnode, matchparts = work.pop()
@ -899,7 +895,7 @@ class Session(nodes.Collector):
# Collect this level of matching. # Collect this level of matching.
# Collecting Session (self) is done directly to avoid endless # Collecting Session (self) is done directly to avoid endless
# recursion to this function. # recursion to this function.
subnodes: Sequence[Union[nodes.Collector, nodes.Item]] subnodes: Sequence[nodes.Collector | nodes.Item]
if isinstance(matchnode, Session): if isinstance(matchnode, Session):
assert isinstance(matchparts[0], Path) assert isinstance(matchparts[0], Path)
subnodes = matchnode._collect_path(matchparts[0], path_cache) subnodes = matchnode._collect_path(matchparts[0], path_cache)
@ -959,9 +955,7 @@ class Session(nodes.Collector):
self.trace.root.indent -= 1 self.trace.root.indent -= 1
def genitems( def genitems(self, node: nodes.Item | nodes.Collector) -> Iterator[nodes.Item]:
self, node: Union[nodes.Item, nodes.Collector]
) -> Iterator[nodes.Item]:
self.trace("genitems", node) self.trace("genitems", node)
if isinstance(node, nodes.Item): if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node) node.ihook.pytest_itemcollected(item=node)
@ -981,7 +975,7 @@ class Session(nodes.Collector):
node.ihook.pytest_collectreport(report=rep) node.ihook.pytest_collectreport(report=rep)
def search_pypath(module_name: str) -> Optional[str]: def search_pypath(module_name: str) -> str | None:
"""Search sys.path for the given a dotted module name, and return its file """Search sys.path for the given a dotted module name, and return its file
system path if found.""" system path if found."""
try: try:
@ -1005,7 +999,7 @@ class CollectionArgument:
path: Path path: Path
parts: Sequence[str] parts: Sequence[str]
module_name: Optional[str] module_name: str | None
def resolve_collection_argument( def resolve_collection_argument(

View File

@ -1,12 +1,12 @@
"""Generic mechanism for marking and selecting python functions.""" """Generic mechanism for marking and selecting python functions."""
from __future__ import annotations
import dataclasses import dataclasses
from typing import AbstractSet from typing import AbstractSet
from typing import Collection from typing import Collection
from typing import List
from typing import Optional from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
from .expression import Expression from .expression import Expression
from .expression import ParseError from .expression import ParseError
@ -44,8 +44,8 @@ old_mark_config_key = StashKey[Optional[Config]]()
def param( def param(
*values: object, *values: object,
marks: Union[MarkDecorator, Collection[Union[MarkDecorator, Mark]]] = (), marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: Optional[str] = None, id: str | None = None,
) -> ParameterSet: ) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or """Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`. :ref:`parametrized fixtures <fixture-parametrize-marks>`.
@ -112,7 +112,7 @@ def pytest_addoption(parser: Parser) -> None:
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
import _pytest.config import _pytest.config
if config.option.markers: if config.option.markers:
@ -151,7 +151,7 @@ class KeywordMatcher:
_names: AbstractSet[str] _names: AbstractSet[str]
@classmethod @classmethod
def from_item(cls, item: "Item") -> "KeywordMatcher": def from_item(cls, item: Item) -> KeywordMatcher:
mapped_names = set() mapped_names = set()
# Add the names of the current item and any parent items, # Add the names of the current item and any parent items,
@ -191,7 +191,7 @@ class KeywordMatcher:
return False return False
def deselect_by_keyword(items: "List[Item]", config: Config) -> None: def deselect_by_keyword(items: list[Item], config: Config) -> None:
keywordexpr = config.option.keyword.lstrip() keywordexpr = config.option.keyword.lstrip()
if not keywordexpr: if not keywordexpr:
return return
@ -223,7 +223,7 @@ class MarkMatcher:
own_mark_names: AbstractSet[str] own_mark_names: AbstractSet[str]
@classmethod @classmethod
def from_item(cls, item: "Item") -> "MarkMatcher": def from_item(cls, item: Item) -> MarkMatcher:
mark_names = {mark.name for mark in item.iter_markers()} mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names) return cls(mark_names)
@ -231,14 +231,14 @@ class MarkMatcher:
return name in self.own_mark_names return name in self.own_mark_names
def deselect_by_mark(items: "List[Item]", config: Config) -> None: def deselect_by_mark(items: list[Item], config: Config) -> None:
matchexpr = config.option.markexpr matchexpr = config.option.markexpr
if not matchexpr: if not matchexpr:
return return
expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'") expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'")
remaining: List[Item] = [] remaining: list[Item] = []
deselected: List[Item] = [] deselected: list[Item] = []
for item in items: for item in items:
if expr.evaluate(MarkMatcher.from_item(item)): if expr.evaluate(MarkMatcher.from_item(item)):
remaining.append(item) remaining.append(item)
@ -256,7 +256,7 @@ def _parse_expression(expr: str, exc_message: str) -> Expression:
raise UsageError(f"{exc_message}: {expr}: {e}") from None raise UsageError(f"{exc_message}: {expr}: {e}") from None
def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None: def pytest_collection_modifyitems(items: list[Item], config: Config) -> None:
deselect_by_keyword(items, config) deselect_by_keyword(items, config)
deselect_by_mark(items, config) deselect_by_mark(items, config)

View File

@ -15,6 +15,8 @@ The semantics are:
- or/and/not evaluate according to the usual boolean semantics. - or/and/not evaluate according to the usual boolean semantics.
""" """
from __future__ import annotations
import ast import ast
import dataclasses import dataclasses
import enum import enum
@ -24,7 +26,6 @@ from typing import Callable
from typing import Iterator from typing import Iterator
from typing import Mapping from typing import Mapping
from typing import NoReturn from typing import NoReturn
from typing import Optional
from typing import Sequence from typing import Sequence
@ -105,7 +106,7 @@ class Scanner:
) )
yield Token(TokenType.EOF, "", pos) yield Token(TokenType.EOF, "", pos)
def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]: def accept(self, type: TokenType, *, reject: bool = False) -> Token | None:
if self.current.type is type: if self.current.type is type:
token = self.current token = self.current
if token.type is not TokenType.EOF: if token.type is not TokenType.EOF:
@ -197,7 +198,7 @@ class Expression:
self.code = code self.code = code
@classmethod @classmethod
def compile(self, input: str) -> "Expression": def compile(self, input: str) -> Expression:
"""Compile a match expression. """Compile a match expression.
:param input: The input expression - one line. :param input: The input expression - one line.

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import collections.abc import collections.abc
import dataclasses import dataclasses
import inspect import inspect
@ -8,16 +10,11 @@ from typing import Collection
from typing import final from typing import final
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Mapping from typing import Mapping
from typing import MutableMapping from typing import MutableMapping
from typing import NamedTuple from typing import NamedTuple
from typing import Optional
from typing import overload from typing import overload
from typing import Sequence from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
@ -47,7 +44,7 @@ def istestfunc(func) -> bool:
def get_empty_parameterset_mark( def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func config: Config, argnames: Sequence[str], func
) -> "MarkDecorator": ) -> MarkDecorator:
from ..nodes import Collector from ..nodes import Collector
fs, lineno = getfslineno(func) fs, lineno = getfslineno(func)
@ -75,17 +72,17 @@ def get_empty_parameterset_mark(
class ParameterSet(NamedTuple): class ParameterSet(NamedTuple):
values: Sequence[Union[object, NotSetType]] values: Sequence[object | NotSetType]
marks: Collection[Union["MarkDecorator", "Mark"]] marks: Collection[MarkDecorator | Mark]
id: Optional[str] id: str | None
@classmethod @classmethod
def param( def param(
cls, cls,
*values: object, *values: object,
marks: Union["MarkDecorator", Collection[Union["MarkDecorator", "Mark"]]] = (), marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: Optional[str] = None, id: str | None = None,
) -> "ParameterSet": ) -> ParameterSet:
if isinstance(marks, MarkDecorator): if isinstance(marks, MarkDecorator):
marks = (marks,) marks = (marks,)
else: else:
@ -100,9 +97,9 @@ class ParameterSet(NamedTuple):
@classmethod @classmethod
def extract_from( def extract_from(
cls, cls,
parameterset: Union["ParameterSet", Sequence[object], object], parameterset: ParameterSet | Sequence[object] | object,
force_tuple: bool = False, force_tuple: bool = False,
) -> "ParameterSet": ) -> ParameterSet:
"""Extract from an object or objects. """Extract from an object or objects.
:param parameterset: :param parameterset:
@ -127,11 +124,11 @@ class ParameterSet(NamedTuple):
@staticmethod @staticmethod
def _parse_parametrize_args( def _parse_parametrize_args(
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
*args, *args,
**kwargs, **kwargs,
) -> Tuple[Sequence[str], bool]: ) -> tuple[Sequence[str], bool]:
if isinstance(argnames, str): if isinstance(argnames, str):
argnames = [x.strip() for x in argnames.split(",") if x.strip()] argnames = [x.strip() for x in argnames.split(",") if x.strip()]
force_tuple = len(argnames) == 1 force_tuple = len(argnames) == 1
@ -141,9 +138,9 @@ class ParameterSet(NamedTuple):
@staticmethod @staticmethod
def _parse_parametrize_parameters( def _parse_parametrize_parameters(
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
force_tuple: bool, force_tuple: bool,
) -> List["ParameterSet"]: ) -> list[ParameterSet]:
return [ return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
] ]
@ -151,12 +148,12 @@ class ParameterSet(NamedTuple):
@classmethod @classmethod
def _for_parametrize( def _for_parametrize(
cls, cls,
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
func, func,
config: Config, config: Config,
nodeid: str, nodeid: str,
) -> Tuple[Sequence[str], List["ParameterSet"]]: ) -> tuple[Sequence[str], list[ParameterSet]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues) argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple) parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues del argvalues
@ -199,24 +196,24 @@ class Mark:
#: Name of the mark. #: Name of the mark.
name: str name: str
#: Positional arguments of the mark decorator. #: Positional arguments of the mark decorator.
args: Tuple[Any, ...] args: tuple[Any, ...]
#: Keyword arguments of the mark decorator. #: Keyword arguments of the mark decorator.
kwargs: Mapping[str, Any] kwargs: Mapping[str, Any]
#: Source Mark for ids with parametrize Marks. #: Source Mark for ids with parametrize Marks.
_param_ids_from: Optional["Mark"] = dataclasses.field(default=None, repr=False) _param_ids_from: Mark | None = dataclasses.field(default=None, repr=False)
#: Resolved/generated ids with parametrize Marks. #: Resolved/generated ids with parametrize Marks.
_param_ids_generated: Optional[Sequence[str]] = dataclasses.field( _param_ids_generated: Sequence[str] | None = dataclasses.field(
default=None, repr=False default=None, repr=False
) )
def __init__( def __init__(
self, self,
name: str, name: str,
args: Tuple[Any, ...], args: tuple[Any, ...],
kwargs: Mapping[str, Any], kwargs: Mapping[str, Any],
param_ids_from: Optional["Mark"] = None, param_ids_from: Mark | None = None,
param_ids_generated: Optional[Sequence[str]] = None, param_ids_generated: Sequence[str] | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -232,7 +229,7 @@ class Mark:
def _has_param_ids(self) -> bool: def _has_param_ids(self) -> bool:
return "ids" in self.kwargs or len(self.args) >= 4 return "ids" in self.kwargs or len(self.args) >= 4
def combined_with(self, other: "Mark") -> "Mark": def combined_with(self, other: Mark) -> Mark:
"""Return a new Mark which is a combination of this """Return a new Mark which is a combination of this
Mark and another Mark. Mark and another Mark.
@ -244,7 +241,7 @@ class Mark:
assert self.name == other.name assert self.name == other.name
# Remember source of ids with parametrize Marks. # Remember source of ids with parametrize Marks.
param_ids_from: Optional[Mark] = None param_ids_from: Mark | None = None
if self.name == "parametrize": if self.name == "parametrize":
if other._has_param_ids(): if other._has_param_ids():
param_ids_from = other param_ids_from = other
@ -315,7 +312,7 @@ class MarkDecorator:
return self.mark.name return self.mark.name
@property @property
def args(self) -> Tuple[Any, ...]: def args(self) -> tuple[Any, ...]:
"""Alias for mark.args.""" """Alias for mark.args."""
return self.mark.args return self.mark.args
@ -329,7 +326,7 @@ class MarkDecorator:
""":meta private:""" """:meta private:"""
return self.name # for backward-compat (2.4.1 had this attr) return self.name # for backward-compat (2.4.1 had this attr)
def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator": def with_args(self, *args: object, **kwargs: object) -> MarkDecorator:
"""Return a MarkDecorator with extra arguments added. """Return a MarkDecorator with extra arguments added.
Unlike calling the MarkDecorator, with_args() can be used even Unlike calling the MarkDecorator, with_args() can be used even
@ -346,7 +343,7 @@ class MarkDecorator:
pass pass
@overload @overload
def __call__(self, *args: object, **kwargs: object) -> "MarkDecorator": def __call__(self, *args: object, **kwargs: object) -> MarkDecorator:
pass pass
def __call__(self, *args: object, **kwargs: object): def __call__(self, *args: object, **kwargs: object):
@ -361,10 +358,10 @@ class MarkDecorator:
def get_unpacked_marks( def get_unpacked_marks(
obj: Union[object, type], obj: object | type,
*, *,
consider_mro: bool = True, consider_mro: bool = True,
) -> List[Mark]: ) -> list[Mark]:
"""Obtain the unpacked marks that are stored on an object. """Obtain the unpacked marks that are stored on an object.
If obj is a class and consider_mro is true, return marks applied to If obj is a class and consider_mro is true, return marks applied to
@ -394,7 +391,7 @@ def get_unpacked_marks(
def normalize_mark_list( def normalize_mark_list(
mark_list: Iterable[Union[Mark, MarkDecorator]], mark_list: Iterable[Mark | MarkDecorator],
) -> Iterable[Mark]: ) -> Iterable[Mark]:
""" """
Normalize an iterable of Mark or MarkDecorator objects into a list of marks Normalize an iterable of Mark or MarkDecorator objects into a list of marks
@ -437,13 +434,13 @@ if TYPE_CHECKING:
def __call__(self, arg: Markable) -> Markable: ... def __call__(self, arg: Markable) -> Markable: ...
@overload @overload
def __call__(self, reason: str = ...) -> "MarkDecorator": ... def __call__(self, reason: str = ...) -> MarkDecorator: ...
class _SkipifMarkDecorator(MarkDecorator): class _SkipifMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override] def __call__( # type: ignore[override]
self, self,
condition: Union[str, bool] = ..., condition: str | bool = ...,
*conditions: Union[str, bool], *conditions: str | bool,
reason: str = ..., reason: str = ...,
) -> MarkDecorator: ... ) -> MarkDecorator: ...
@ -454,30 +451,25 @@ if TYPE_CHECKING:
@overload @overload
def __call__( def __call__(
self, self,
condition: Union[str, bool] = False, condition: str | bool = False,
*conditions: Union[str, bool], *conditions: str | bool,
reason: str = ..., reason: str = ...,
run: bool = ..., run: bool = ...,
raises: Union[ raises: None | type[BaseException] | tuple[type[BaseException], ...] = ...,
None, Type[BaseException], Tuple[Type[BaseException], ...]
] = ...,
strict: bool = ..., strict: bool = ...,
) -> MarkDecorator: ... ) -> MarkDecorator: ...
class _ParametrizeMarkDecorator(MarkDecorator): class _ParametrizeMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override] def __call__( # type: ignore[override]
self, self,
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union[ParameterSet, Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
*, *,
indirect: Union[bool, Sequence[str]] = ..., indirect: bool | Sequence[str] = ...,
ids: Optional[ ids: Iterable[None | str | float | int | bool]
Union[ | Callable[[Any], object | None]
Iterable[Union[None, str, float, int, bool]], | None = ...,
Callable[[Any], Optional[object]], scope: _ScopeName | None = ...,
]
] = ...,
scope: Optional[_ScopeName] = ...,
) -> MarkDecorator: ... ) -> MarkDecorator: ...
class _UsefixturesMarkDecorator(MarkDecorator): class _UsefixturesMarkDecorator(MarkDecorator):
@ -517,8 +509,8 @@ class MarkGenerator:
def __init__(self, *, _ispytest: bool = False) -> None: def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._config: Optional[Config] = None self._config: Config | None = None
self._markers: Set[str] = set() self._markers: set[str] = set()
def __getattr__(self, name: str) -> MarkDecorator: def __getattr__(self, name: str) -> MarkDecorator:
"""Generate a new :class:`MarkDecorator` with the given name.""" """Generate a new :class:`MarkDecorator` with the given name."""
@ -569,7 +561,7 @@ MARK_GEN = MarkGenerator(_ispytest=True)
class NodeKeywords(MutableMapping[str, Any]): class NodeKeywords(MutableMapping[str, Any]):
__slots__ = ("node", "parent", "_markers") __slots__ = ("node", "parent", "_markers")
def __init__(self, node: "Node") -> None: def __init__(self, node: Node) -> None:
self.node = node self.node = node
self.parent = node.parent self.parent = node.parent
self._markers = {node.name: True} self._markers = {node.name: True}
@ -597,7 +589,7 @@ class NodeKeywords(MutableMapping[str, Any]):
def update( # type: ignore[override] def update( # type: ignore[override]
self, self,
other: Union[Mapping[str, Any], Iterable[Tuple[str, Any]]] = (), other: Mapping[str, Any] | Iterable[tuple[str, Any]] = (),
**kwds: Any, **kwds: Any,
) -> None: ) -> None:
self._markers.update(other) self._markers.update(other)

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Monkeypatching and mocking functionality.""" """Monkeypatching and mocking functionality."""
from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
import os import os
import re import re
@ -8,14 +10,10 @@ import sys
from typing import Any from typing import Any
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import List
from typing import Mapping from typing import Mapping
from typing import MutableMapping from typing import MutableMapping
from typing import Optional
from typing import overload from typing import overload
from typing import Tuple
from typing import TypeVar from typing import TypeVar
from typing import Union
import warnings import warnings
from _pytest.fixtures import fixture from _pytest.fixtures import fixture
@ -30,7 +28,7 @@ V = TypeVar("V")
@fixture @fixture
def monkeypatch() -> Generator["MonkeyPatch", None, None]: def monkeypatch() -> Generator[MonkeyPatch, None, None]:
"""A convenient fixture for monkey-patching. """A convenient fixture for monkey-patching.
The fixture provides these methods to modify objects, dictionaries, or The fixture provides these methods to modify objects, dictionaries, or
@ -97,7 +95,7 @@ def annotated_getattr(obj: object, name: str, ann: str) -> object:
return obj return obj
def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]: def derive_importpath(import_path: str, raising: bool) -> tuple[str, object]:
if not isinstance(import_path, str) or "." not in import_path: if not isinstance(import_path, str) or "." not in import_path:
raise TypeError(f"must be absolute import path string, not {import_path!r}") raise TypeError(f"must be absolute import path string, not {import_path!r}")
module, attr = import_path.rsplit(".", 1) module, attr = import_path.rsplit(".", 1)
@ -130,14 +128,14 @@ class MonkeyPatch:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._setattr: List[Tuple[object, str, object]] = [] self._setattr: list[tuple[object, str, object]] = []
self._setitem: List[Tuple[Mapping[Any, Any], object, object]] = [] self._setitem: list[tuple[Mapping[Any, Any], object, object]] = []
self._cwd: Optional[str] = None self._cwd: str | None = None
self._savesyspath: Optional[List[str]] = None self._savesyspath: list[str] | None = None
@classmethod @classmethod
@contextmanager @contextmanager
def context(cls) -> Generator["MonkeyPatch", None, None]: def context(cls) -> Generator[MonkeyPatch, None, None]:
"""Context manager that returns a new :class:`MonkeyPatch` object """Context manager that returns a new :class:`MonkeyPatch` object
which undoes any patching done inside the ``with`` block upon exit. which undoes any patching done inside the ``with`` block upon exit.
@ -181,8 +179,8 @@ class MonkeyPatch:
def setattr( def setattr(
self, self,
target: Union[str, object], target: str | object,
name: Union[object, str], name: object | str,
value: object = notset, value: object = notset,
raising: bool = True, raising: bool = True,
) -> None: ) -> None:
@ -253,8 +251,8 @@ class MonkeyPatch:
def delattr( def delattr(
self, self,
target: Union[object, str], target: object | str,
name: Union[str, Notset] = notset, name: str | Notset = notset,
raising: bool = True, raising: bool = True,
) -> None: ) -> None:
"""Delete attribute ``name`` from ``target``. """Delete attribute ``name`` from ``target``.
@ -309,7 +307,7 @@ class MonkeyPatch:
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict # Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dic[name] # type: ignore[attr-defined] del dic[name] # type: ignore[attr-defined]
def setenv(self, name: str, value: str, prepend: Optional[str] = None) -> None: def setenv(self, name: str, value: str, prepend: str | None = None) -> None:
"""Set environment variable ``name`` to ``value``. """Set environment variable ``name`` to ``value``.
If ``prepend`` is a character, read the current environment variable If ``prepend`` is a character, read the current environment variable
@ -362,7 +360,7 @@ class MonkeyPatch:
invalidate_caches() invalidate_caches()
def chdir(self, path: Union[str, "os.PathLike[str]"]) -> None: def chdir(self, path: str | os.PathLike[str]) -> None:
"""Change the current working directory to the specified path. """Change the current working directory to the specified path.
:param path: :param path:

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import abc import abc
from functools import cached_property from functools import cached_property
from inspect import signature from inspect import signature
@ -10,17 +12,11 @@ from typing import Callable
from typing import cast from typing import cast
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import MutableMapping from typing import MutableMapping
from typing import NoReturn from typing import NoReturn
from typing import Optional
from typing import overload from typing import overload
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union
import warnings import warnings
import pluggy import pluggy
@ -62,9 +58,9 @@ _T = TypeVar("_T")
def _imply_path( def _imply_path(
node_type: Type["Node"], node_type: type[Node],
path: Optional[Path], path: Path | None,
fspath: Optional[LEGACY_PATH], fspath: LEGACY_PATH | None,
) -> Path: ) -> Path:
if fspath is not None: if fspath is not None:
warnings.warn( warnings.warn(
@ -109,7 +105,7 @@ class NodeMeta(abc.ABCMeta):
).format(name=f"{cls.__module__}.{cls.__name__}") ).format(name=f"{cls.__module__}.{cls.__name__}")
fail(msg, pytrace=False) fail(msg, pytrace=False)
def _create(cls: Type[_T], *k, **kw) -> _T: def _create(cls: type[_T], *k, **kw) -> _T:
try: try:
return super().__call__(*k, **kw) # type: ignore[no-any-return,misc] return super().__call__(*k, **kw) # type: ignore[no-any-return,misc]
except TypeError: except TypeError:
@ -160,12 +156,12 @@ class Node(abc.ABC, metaclass=NodeMeta):
def __init__( def __init__(
self, self,
name: str, name: str,
parent: "Optional[Node]" = None, parent: Node | None = None,
config: Optional[Config] = None, config: Config | None = None,
session: "Optional[Session]" = None, session: Session | None = None,
fspath: Optional[LEGACY_PATH] = None, fspath: LEGACY_PATH | None = None,
path: Optional[Path] = None, path: Path | None = None,
nodeid: Optional[str] = None, nodeid: str | None = None,
) -> None: ) -> None:
#: A unique name within the scope of the parent node. #: A unique name within the scope of the parent node.
self.name: str = name self.name: str = name
@ -199,10 +195,10 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.keywords: MutableMapping[str, Any] = NodeKeywords(self) self.keywords: MutableMapping[str, Any] = NodeKeywords(self)
#: The marker objects belonging to this node. #: The marker objects belonging to this node.
self.own_markers: List[Mark] = [] self.own_markers: list[Mark] = []
#: Allow adding of extra keywords to use for matching. #: Allow adding of extra keywords to use for matching.
self.extra_keyword_matches: Set[str] = set() self.extra_keyword_matches: set[str] = set()
if nodeid is not None: if nodeid is not None:
assert "::()" not in nodeid assert "::()" not in nodeid
@ -219,7 +215,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
self._store = self.stash self._store = self.stash
@classmethod @classmethod
def from_parent(cls, parent: "Node", **kw) -> "Self": def from_parent(cls, parent: Node, **kw) -> Self:
"""Public constructor for Nodes. """Public constructor for Nodes.
This indirection got introduced in order to enable removing This indirection got introduced in order to enable removing
@ -295,31 +291,29 @@ class Node(abc.ABC, metaclass=NodeMeta):
def teardown(self) -> None: def teardown(self) -> None:
pass pass
def iter_parents(self) -> Iterator["Node"]: def iter_parents(self) -> Iterator[Node]:
"""Iterate over all parent collectors starting from and including self """Iterate over all parent collectors starting from and including self
up to the root of the collection tree. up to the root of the collection tree.
.. versionadded:: 8.1 .. versionadded:: 8.1
""" """
parent: Optional[Node] = self parent: Node | None = self
while parent is not None: while parent is not None:
yield parent yield parent
parent = parent.parent parent = parent.parent
def listchain(self) -> List["Node"]: def listchain(self) -> list[Node]:
"""Return a list of all parent collectors starting from the root of the """Return a list of all parent collectors starting from the root of the
collection tree down to and including self.""" collection tree down to and including self."""
chain = [] chain = []
item: Optional[Node] = self item: Node | None = self
while item is not None: while item is not None:
chain.append(item) chain.append(item)
item = item.parent item = item.parent
chain.reverse() chain.reverse()
return chain return chain
def add_marker( def add_marker(self, marker: str | MarkDecorator, append: bool = True) -> None:
self, marker: Union[str, MarkDecorator], append: bool = True
) -> None:
"""Dynamically add a marker object to the node. """Dynamically add a marker object to the node.
:param marker: :param marker:
@ -341,7 +335,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
else: else:
self.own_markers.insert(0, marker_.mark) self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]: def iter_markers(self, name: str | None = None) -> Iterator[Mark]:
"""Iterate over all markers of the node. """Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute. :param name: If given, filter the results by the name attribute.
@ -350,8 +344,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
return (x[1] for x in self.iter_markers_with_node(name=name)) return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node( def iter_markers_with_node(
self, name: Optional[str] = None self, name: str | None = None
) -> Iterator[Tuple["Node", Mark]]: ) -> Iterator[tuple[Node, Mark]]:
"""Iterate over all markers of the node. """Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute. :param name: If given, filter the results by the name attribute.
@ -363,14 +357,12 @@ class Node(abc.ABC, metaclass=NodeMeta):
yield node, mark yield node, mark
@overload @overload
def get_closest_marker(self, name: str) -> Optional[Mark]: ... def get_closest_marker(self, name: str) -> Mark | None: ...
@overload @overload
def get_closest_marker(self, name: str, default: Mark) -> Mark: ... def get_closest_marker(self, name: str, default: Mark) -> Mark: ...
def get_closest_marker( def get_closest_marker(self, name: str, default: Mark | None = None) -> Mark | None:
self, name: str, default: Optional[Mark] = None
) -> Optional[Mark]:
"""Return the first marker matching the name, from closest (for """Return the first marker matching the name, from closest (for
example function) to farther level (for example module level). example function) to farther level (for example module level).
@ -379,14 +371,14 @@ class Node(abc.ABC, metaclass=NodeMeta):
""" """
return next(self.iter_markers(name=name), default) return next(self.iter_markers(name=name), default)
def listextrakeywords(self) -> Set[str]: def listextrakeywords(self) -> set[str]:
"""Return a set of all extra keywords in self and any parents.""" """Return a set of all extra keywords in self and any parents."""
extra_keywords: Set[str] = set() extra_keywords: set[str] = set()
for item in self.listchain(): for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches) extra_keywords.update(item.extra_keyword_matches)
return extra_keywords return extra_keywords
def listnames(self) -> List[str]: def listnames(self) -> list[str]:
return [x.name for x in self.listchain()] return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None: def addfinalizer(self, fin: Callable[[], object]) -> None:
@ -398,7 +390,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
""" """
self.session._setupstate.addfinalizer(fin, self) self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]: def getparent(self, cls: type[_NodeType]) -> _NodeType | None:
"""Get the closest parent node (including self) which is an instance of """Get the closest parent node (including self) which is an instance of
the given class. the given class.
@ -416,7 +408,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
def _repr_failure_py( def _repr_failure_py(
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None, style: _TracebackStyle | None = None,
) -> TerminalRepr: ) -> TerminalRepr:
from _pytest.fixtures import FixtureLookupError from _pytest.fixtures import FixtureLookupError
@ -428,7 +420,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
if isinstance(excinfo.value, FixtureLookupError): if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr() return excinfo.value.formatrepr()
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]] tbfilter: bool | Callable[[ExceptionInfo[BaseException]], Traceback]
if self.config.getoption("fulltrace", False): if self.config.getoption("fulltrace", False):
style = "long" style = "long"
tbfilter = False tbfilter = False
@ -471,8 +463,8 @@ class Node(abc.ABC, metaclass=NodeMeta):
def repr_failure( def repr_failure(
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None, style: _TracebackStyle | None = None,
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
"""Return a representation of a collection or test failure. """Return a representation of a collection or test failure.
.. seealso:: :ref:`non-python tests` .. seealso:: :ref:`non-python tests`
@ -482,7 +474,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
return self._repr_failure_py(excinfo, style) return self._repr_failure_py(excinfo, style)
def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[int]]: def get_fslocation_from_item(node: Node) -> tuple[str | Path, int | None]:
"""Try to extract the actual location from a node, depending on available attributes: """Try to extract the actual location from a node, depending on available attributes:
* "location": a pair (path, lineno) * "location": a pair (path, lineno)
@ -492,7 +484,7 @@ def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[i
:rtype: A tuple of (str|Path, int) with filename and 0-based line number. :rtype: A tuple of (str|Path, int) with filename and 0-based line number.
""" """
# See Item.location. # See Item.location.
location: Optional[Tuple[str, Optional[int], str]] = getattr(node, "location", None) location: tuple[str, int | None, str] | None = getattr(node, "location", None)
if location is not None: if location is not None:
return location[:2] return location[:2]
obj = getattr(node, "obj", None) obj = getattr(node, "obj", None)
@ -512,14 +504,14 @@ class Collector(Node, abc.ABC):
"""An error during collection, contains a custom message.""" """An error during collection, contains a custom message."""
@abc.abstractmethod @abc.abstractmethod
def collect(self) -> Iterable[Union["Item", "Collector"]]: def collect(self) -> Iterable[Item | Collector]:
"""Collect children (items and collectors) for this collector.""" """Collect children (items and collectors) for this collector."""
raise NotImplementedError("abstract") raise NotImplementedError("abstract")
# TODO: This omits the style= parameter which breaks Liskov Substitution. # TODO: This omits the style= parameter which breaks Liskov Substitution.
def repr_failure( # type: ignore[override] def repr_failure( # type: ignore[override]
self, excinfo: ExceptionInfo[BaseException] self, excinfo: ExceptionInfo[BaseException]
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
"""Return a representation of a collection failure. """Return a representation of a collection failure.
:param excinfo: Exception information for the failure. :param excinfo: Exception information for the failure.
@ -548,7 +540,7 @@ class Collector(Node, abc.ABC):
return excinfo.traceback return excinfo.traceback
def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]: def _check_initialpaths_for_relpath(session: Session, path: Path) -> str | None:
for initial_path in session._initialpaths: for initial_path in session._initialpaths:
if commonpath(path, initial_path) == initial_path: if commonpath(path, initial_path) == initial_path:
rel = str(path.relative_to(initial_path)) rel = str(path.relative_to(initial_path))
@ -561,14 +553,14 @@ class FSCollector(Collector, abc.ABC):
def __init__( def __init__(
self, self,
fspath: Optional[LEGACY_PATH] = None, fspath: LEGACY_PATH | None = None,
path_or_parent: Optional[Union[Path, Node]] = None, path_or_parent: Path | Node | None = None,
path: Optional[Path] = None, path: Path | None = None,
name: Optional[str] = None, name: str | None = None,
parent: Optional[Node] = None, parent: Node | None = None,
config: Optional[Config] = None, config: Config | None = None,
session: Optional["Session"] = None, session: Session | None = None,
nodeid: Optional[str] = None, nodeid: str | None = None,
) -> None: ) -> None:
if path_or_parent: if path_or_parent:
if isinstance(path_or_parent, Node): if isinstance(path_or_parent, Node):
@ -618,10 +610,10 @@ class FSCollector(Collector, abc.ABC):
cls, cls,
parent, parent,
*, *,
fspath: Optional[LEGACY_PATH] = None, fspath: LEGACY_PATH | None = None,
path: Optional[Path] = None, path: Path | None = None,
**kw, **kw,
) -> "Self": ) -> Self:
"""The public constructor.""" """The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, path=path, **kw) return super().from_parent(parent=parent, fspath=fspath, path=path, **kw)
@ -663,9 +655,9 @@ class Item(Node, abc.ABC):
self, self,
name, name,
parent=None, parent=None,
config: Optional[Config] = None, config: Config | None = None,
session: Optional["Session"] = None, session: Session | None = None,
nodeid: Optional[str] = None, nodeid: str | None = None,
**kw, **kw,
) -> None: ) -> None:
# The first two arguments are intentionally passed positionally, # The first two arguments are intentionally passed positionally,
@ -680,11 +672,11 @@ class Item(Node, abc.ABC):
nodeid=nodeid, nodeid=nodeid,
**kw, **kw,
) )
self._report_sections: List[Tuple[str, str, str]] = [] self._report_sections: list[tuple[str, str, str]] = []
#: A list of tuples (name, value) that holds user defined properties #: A list of tuples (name, value) that holds user defined properties
#: for this test. #: for this test.
self.user_properties: List[Tuple[str, object]] = [] self.user_properties: list[tuple[str, object]] = []
self._check_item_and_collector_diamond_inheritance() self._check_item_and_collector_diamond_inheritance()
@ -744,7 +736,7 @@ class Item(Node, abc.ABC):
if content: if content:
self._report_sections.append((when, key, content)) self._report_sections.append((when, key, content))
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]: def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
"""Get location information for this item for test reports. """Get location information for this item for test reports.
Returns a tuple with three elements: Returns a tuple with three elements:
@ -758,7 +750,7 @@ class Item(Node, abc.ABC):
return self.path, None, "" return self.path, None, ""
@cached_property @cached_property
def location(self) -> Tuple[str, Optional[int], str]: def location(self) -> tuple[str, int | None, str]:
""" """
Returns a tuple of ``(relfspath, lineno, testname)`` for this item Returns a tuple of ``(relfspath, lineno, testname)`` for this item
where ``relfspath`` is file path relative to ``config.rootpath`` where ``relfspath`` is file path relative to ``config.rootpath``

View File

@ -1,12 +1,13 @@
"""Exception classes and constants handling test outcomes as well as """Exception classes and constants handling test outcomes as well as
functions creating them.""" functions creating them."""
from __future__ import annotations
import sys import sys
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import NoReturn from typing import NoReturn
from typing import Optional
from typing import Protocol from typing import Protocol
from typing import Type from typing import Type
from typing import TypeVar from typing import TypeVar
@ -18,7 +19,7 @@ class OutcomeException(BaseException):
"""OutcomeException and its subclass instances indicate and contain info """OutcomeException and its subclass instances indicate and contain info
about test and collection outcomes.""" about test and collection outcomes."""
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None: def __init__(self, msg: str | None = None, pytrace: bool = True) -> None:
if msg is not None and not isinstance(msg, str): if msg is not None and not isinstance(msg, str):
error_msg = ( # type: ignore[unreachable] error_msg = ( # type: ignore[unreachable]
"{} expected string as 'msg' parameter, got '{}' instead.\n" "{} expected string as 'msg' parameter, got '{}' instead.\n"
@ -47,7 +48,7 @@ class Skipped(OutcomeException):
def __init__( def __init__(
self, self,
msg: Optional[str] = None, msg: str | None = None,
pytrace: bool = True, pytrace: bool = True,
allow_module_level: bool = False, allow_module_level: bool = False,
*, *,
@ -70,7 +71,7 @@ class Exit(Exception):
"""Raised for immediate program exits (no tracebacks/summaries).""" """Raised for immediate program exits (no tracebacks/summaries)."""
def __init__( def __init__(
self, msg: str = "unknown reason", returncode: Optional[int] = None self, msg: str = "unknown reason", returncode: int | None = None
) -> None: ) -> None:
self.msg = msg self.msg = msg
self.returncode = returncode self.returncode = returncode
@ -104,7 +105,7 @@ def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _E
@_with_exception(Exit) @_with_exception(Exit)
def exit( def exit(
reason: str = "", reason: str = "",
returncode: Optional[int] = None, returncode: int | None = None,
) -> NoReturn: ) -> NoReturn:
"""Exit testing process. """Exit testing process.
@ -207,10 +208,10 @@ def xfail(reason: str = "") -> NoReturn:
def importorskip( def importorskip(
modname: str, modname: str,
minversion: Optional[str] = None, minversion: str | None = None,
reason: Optional[str] = None, reason: str | None = None,
*, *,
exc_type: Optional[Type[ImportError]] = None, exc_type: type[ImportError] | None = None,
) -> Any: ) -> Any:
"""Import and return the requested module ``modname``, or skip the """Import and return the requested module ``modname``, or skip the
current test if the module cannot be imported. current test if the module cannot be imported.
@ -267,8 +268,8 @@ def importorskip(
else: else:
warn_on_import_error = False warn_on_import_error = False
skipped: Optional[Skipped] = None skipped: Skipped | None = None
warning: Optional[Warning] = None warning: Warning | None = None
with warnings.catch_warnings(): with warnings.catch_warnings():
# Make sure to ignore ImportWarnings that might happen because # Make sure to ignore ImportWarnings that might happen because

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Submit failure or test session information to a pastebin service.""" """Submit failure or test session information to a pastebin service."""
from __future__ import annotations
from io import StringIO from io import StringIO
import tempfile import tempfile
from typing import IO from typing import IO
from typing import Union
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import create_terminal_writer from _pytest.config import create_terminal_writer
@ -68,7 +69,7 @@ def pytest_unconfigure(config: Config) -> None:
tr.write_line("pastebin session-log: %s\n" % pastebinurl) tr.write_line("pastebin session-log: %s\n" % pastebinurl)
def create_new_paste(contents: Union[str, bytes]) -> str: def create_new_paste(contents: str | bytes) -> str:
"""Create a new paste using the bpaste.net service. """Create a new paste using the bpaste.net service.
:contents: Paste contents string. :contents: Paste contents string.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import atexit import atexit
import contextlib import contextlib
from enum import Enum from enum import Enum
@ -24,16 +26,9 @@ import types
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar from typing import TypeVar
from typing import Union
import uuid import uuid
import warnings import warnings
@ -71,12 +66,10 @@ def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
def on_rm_rf_error( def on_rm_rf_error(
func: Optional[Callable[..., Any]], func: Callable[..., Any] | None,
path: str, path: str,
excinfo: Union[ excinfo: BaseException
BaseException, | tuple[type[BaseException], BaseException, types.TracebackType | None],
Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]],
],
*, *,
start_path: Path, start_path: Path,
) -> bool: ) -> bool:
@ -172,7 +165,7 @@ def rm_rf(path: Path) -> None:
shutil.rmtree(str(path), onerror=onerror) shutil.rmtree(str(path), onerror=onerror)
def find_prefixed(root: Path, prefix: str) -> Iterator["os.DirEntry[str]"]: def find_prefixed(root: Path, prefix: str) -> Iterator[os.DirEntry[str]]:
"""Find all elements in root that begin with the prefix, case-insensitive.""" """Find all elements in root that begin with the prefix, case-insensitive."""
l_prefix = prefix.lower() l_prefix = prefix.lower()
for x in os.scandir(root): for x in os.scandir(root):
@ -180,7 +173,7 @@ def find_prefixed(root: Path, prefix: str) -> Iterator["os.DirEntry[str]"]:
yield x yield x
def extract_suffixes(iter: Iterable["os.DirEntry[str]"], prefix: str) -> Iterator[str]: def extract_suffixes(iter: Iterable[os.DirEntry[str]], prefix: str) -> Iterator[str]:
"""Return the parts of the paths following the prefix. """Return the parts of the paths following the prefix.
:param iter: Iterator over path names. :param iter: Iterator over path names.
@ -204,9 +197,7 @@ def parse_num(maybe_num: str) -> int:
return -1 return -1
def _force_symlink( def _force_symlink(root: Path, target: str | PurePath, link_to: str | Path) -> None:
root: Path, target: Union[str, PurePath], link_to: Union[str, Path]
) -> None:
"""Helper to create the current symlink. """Helper to create the current symlink.
It's full of race conditions that are reasonably OK to ignore It's full of race conditions that are reasonably OK to ignore
@ -420,7 +411,7 @@ def resolve_from_str(input: str, rootpath: Path) -> Path:
return rootpath.joinpath(input) return rootpath.joinpath(input)
def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool: def fnmatch_ex(pattern: str, path: str | os.PathLike[str]) -> bool:
"""A port of FNMatcher from py.path.common which works with PurePath() instances. """A port of FNMatcher from py.path.common which works with PurePath() instances.
The difference between this algorithm and PurePath.match() is that the The difference between this algorithm and PurePath.match() is that the
@ -456,14 +447,14 @@ def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
return fnmatch.fnmatch(name, pattern) return fnmatch.fnmatch(name, pattern)
def parts(s: str) -> Set[str]: def parts(s: str) -> set[str]:
parts = s.split(sep) parts = s.split(sep)
return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))} return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))}
def symlink_or_skip( def symlink_or_skip(
src: Union["os.PathLike[str]", str], src: os.PathLike[str] | str,
dst: Union["os.PathLike[str]", str], dst: os.PathLike[str] | str,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Make a symlink, or skip the test in case symlinks are not supported.""" """Make a symlink, or skip the test in case symlinks are not supported."""
@ -491,9 +482,9 @@ class ImportPathMismatchError(ImportError):
def import_path( def import_path(
path: Union[str, "os.PathLike[str]"], path: str | os.PathLike[str],
*, *,
mode: Union[str, ImportMode] = ImportMode.prepend, mode: str | ImportMode = ImportMode.prepend,
root: Path, root: Path,
consider_namespace_packages: bool, consider_namespace_packages: bool,
) -> ModuleType: ) -> ModuleType:
@ -618,7 +609,7 @@ def import_path(
def _import_module_using_spec( def _import_module_using_spec(
module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool module_name: str, module_path: Path, module_location: Path, *, insert_modules: bool
) -> Optional[ModuleType]: ) -> ModuleType | None:
""" """
Tries to import a module by its canonical name, path to the .py file, and its Tries to import a module by its canonical name, path to the .py file, and its
parent location. parent location.
@ -641,7 +632,7 @@ def _import_module_using_spec(
# Attempt to import the parent module, seems is our responsibility: # Attempt to import the parent module, seems is our responsibility:
# https://github.com/python/cpython/blob/73906d5c908c1e0b73c5436faeff7d93698fc074/Lib/importlib/_bootstrap.py#L1308-L1311 # https://github.com/python/cpython/blob/73906d5c908c1e0b73c5436faeff7d93698fc074/Lib/importlib/_bootstrap.py#L1308-L1311
parent_module_name, _, name = module_name.rpartition(".") parent_module_name, _, name = module_name.rpartition(".")
parent_module: Optional[ModuleType] = None parent_module: ModuleType | None = None
if parent_module_name: if parent_module_name:
parent_module = sys.modules.get(parent_module_name) parent_module = sys.modules.get(parent_module_name)
if parent_module is None: if parent_module is None:
@ -680,9 +671,7 @@ def _import_module_using_spec(
return None return None
def spec_matches_module_path( def spec_matches_module_path(module_spec: ModuleSpec | None, module_path: Path) -> bool:
module_spec: Optional[ModuleSpec], module_path: Path
) -> bool:
"""Return true if the given ModuleSpec can be used to import the given module path.""" """Return true if the given ModuleSpec can be used to import the given module path."""
if module_spec is None or module_spec.origin is None: if module_spec is None or module_spec.origin is None:
return False return False
@ -734,7 +723,7 @@ def module_name_from_path(path: Path, root: Path) -> str:
return ".".join(path_parts) return ".".join(path_parts)
def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) -> None: def insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None:
""" """
Used by ``import_path`` to create intermediate modules when using mode=importlib. Used by ``import_path`` to create intermediate modules when using mode=importlib.
@ -772,7 +761,7 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
module_name = ".".join(module_parts) module_name = ".".join(module_parts)
def resolve_package_path(path: Path) -> Optional[Path]: def resolve_package_path(path: Path) -> Path | None:
"""Return the Python package path by looking for the last """Return the Python package path by looking for the last
directory upwards which still contains an __init__.py. directory upwards which still contains an __init__.py.
@ -791,7 +780,7 @@ def resolve_package_path(path: Path) -> Optional[Path]:
def resolve_pkg_root_and_module_name( def resolve_pkg_root_and_module_name(
path: Path, *, consider_namespace_packages: bool = False path: Path, *, consider_namespace_packages: bool = False
) -> Tuple[Path, str]: ) -> tuple[Path, str]:
""" """
Return the path to the directory of the root package that contains the Return the path to the directory of the root package that contains the
given Python file, and its module name: given Python file, and its module name:
@ -812,7 +801,7 @@ def resolve_pkg_root_and_module_name(
Raises CouldNotResolvePathError if the given path does not belong to a package (missing any __init__.py files). Raises CouldNotResolvePathError if the given path does not belong to a package (missing any __init__.py files).
""" """
pkg_root: Optional[Path] = None pkg_root: Path | None = None
pkg_path = resolve_package_path(path) pkg_path = resolve_package_path(path)
if pkg_path is not None: if pkg_path is not None:
pkg_root = pkg_path.parent pkg_root = pkg_path.parent
@ -859,7 +848,7 @@ def is_importable(module_name: str, module_path: Path) -> bool:
return spec_matches_module_path(spec, module_path) return spec_matches_module_path(spec, module_path)
def compute_module_name(root: Path, module_path: Path) -> Optional[str]: def compute_module_name(root: Path, module_path: Path) -> str | None:
"""Compute a module name based on a path and a root anchor.""" """Compute a module name based on a path and a root anchor."""
try: try:
path_without_suffix = module_path.with_suffix("") path_without_suffix = module_path.with_suffix("")
@ -884,9 +873,9 @@ class CouldNotResolvePathError(Exception):
def scandir( def scandir(
path: Union[str, "os.PathLike[str]"], path: str | os.PathLike[str],
sort_key: Callable[["os.DirEntry[str]"], object] = lambda entry: entry.name, sort_key: Callable[[os.DirEntry[str]], object] = lambda entry: entry.name,
) -> List["os.DirEntry[str]"]: ) -> list[os.DirEntry[str]]:
"""Scan a directory recursively, in breadth-first order. """Scan a directory recursively, in breadth-first order.
The returned entries are sorted according to the given key. The returned entries are sorted according to the given key.
@ -909,8 +898,8 @@ def scandir(
def visit( def visit(
path: Union[str, "os.PathLike[str]"], recurse: Callable[["os.DirEntry[str]"], bool] path: str | os.PathLike[str], recurse: Callable[[os.DirEntry[str]], bool]
) -> Iterator["os.DirEntry[str]"]: ) -> Iterator[os.DirEntry[str]]:
"""Walk a directory recursively, in breadth-first order. """Walk a directory recursively, in breadth-first order.
The `recurse` predicate determines whether a directory is recursed. The `recurse` predicate determines whether a directory is recursed.
@ -924,7 +913,7 @@ def visit(
yield from visit(entry.path, recurse) yield from visit(entry.path, recurse)
def absolutepath(path: "Union[str, os.PathLike[str]]") -> Path: def absolutepath(path: str | os.PathLike[str]) -> Path:
"""Convert a path to an absolute path using os.path.abspath. """Convert a path to an absolute path using os.path.abspath.
Prefer this over Path.resolve() (see #6523). Prefer this over Path.resolve() (see #6523).
@ -933,7 +922,7 @@ def absolutepath(path: "Union[str, os.PathLike[str]]") -> Path:
return Path(os.path.abspath(path)) return Path(os.path.abspath(path))
def commonpath(path1: Path, path2: Path) -> Optional[Path]: def commonpath(path1: Path, path2: Path) -> Path | None:
"""Return the common part shared with the other path, or None if there is """Return the common part shared with the other path, or None if there is
no common part. no common part.

View File

@ -4,6 +4,8 @@
PYTEST_DONT_REWRITE PYTEST_DONT_REWRITE
""" """
from __future__ import annotations
import collections.abc import collections.abc
import contextlib import contextlib
from fnmatch import fnmatch from fnmatch import fnmatch
@ -21,22 +23,16 @@ import sys
import traceback import traceback
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict
from typing import Final from typing import Final
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import IO from typing import IO
from typing import Iterable from typing import Iterable
from typing import List
from typing import Literal from typing import Literal
from typing import Optional
from typing import overload from typing import overload
from typing import Sequence from typing import Sequence
from typing import TextIO from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
from iniconfig import IniConfig from iniconfig import IniConfig
@ -123,7 +119,7 @@ def pytest_configure(config: Config) -> None:
class LsofFdLeakChecker: class LsofFdLeakChecker:
def get_open_files(self) -> List[Tuple[str, str]]: def get_open_files(self) -> list[tuple[str, str]]:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
# New in Python 3.11, ignores utf-8 mode # New in Python 3.11, ignores utf-8 mode
encoding = locale.getencoding() encoding = locale.getencoding()
@ -199,7 +195,7 @@ class LsofFdLeakChecker:
@fixture @fixture
def _pytest(request: FixtureRequest) -> "PytestArg": def _pytest(request: FixtureRequest) -> PytestArg:
"""Return a helper which offers a gethookrecorder(hook) method which """Return a helper which offers a gethookrecorder(hook) method which
returns a HookRecorder instance which helps to make assertions about called returns a HookRecorder instance which helps to make assertions about called
hooks.""" hooks."""
@ -210,13 +206,13 @@ class PytestArg:
def __init__(self, request: FixtureRequest) -> None: def __init__(self, request: FixtureRequest) -> None:
self._request = request self._request = request
def gethookrecorder(self, hook) -> "HookRecorder": def gethookrecorder(self, hook) -> HookRecorder:
hookrecorder = HookRecorder(hook._pm) hookrecorder = HookRecorder(hook._pm)
self._request.addfinalizer(hookrecorder.finish_recording) self._request.addfinalizer(hookrecorder.finish_recording)
return hookrecorder return hookrecorder
def get_public_names(values: Iterable[str]) -> List[str]: def get_public_names(values: Iterable[str]) -> list[str]:
"""Only return names from iterator values without a leading underscore.""" """Only return names from iterator values without a leading underscore."""
return [x for x in values if x[0] != "_"] return [x for x in values if x[0] != "_"]
@ -265,8 +261,8 @@ class HookRecorder:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._pluginmanager = pluginmanager self._pluginmanager = pluginmanager
self.calls: List[RecordedHookCall] = [] self.calls: list[RecordedHookCall] = []
self.ret: Optional[Union[int, ExitCode]] = None self.ret: int | ExitCode | None = None
def before(hook_name: str, hook_impls, kwargs) -> None: def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(RecordedHookCall(hook_name, kwargs)) self.calls.append(RecordedHookCall(hook_name, kwargs))
@ -279,13 +275,13 @@ class HookRecorder:
def finish_recording(self) -> None: def finish_recording(self) -> None:
self._undo_wrapping() self._undo_wrapping()
def getcalls(self, names: Union[str, Iterable[str]]) -> List[RecordedHookCall]: def getcalls(self, names: str | Iterable[str]) -> list[RecordedHookCall]:
"""Get all recorded calls to hooks with the given names (or name).""" """Get all recorded calls to hooks with the given names (or name)."""
if isinstance(names, str): if isinstance(names, str):
names = names.split() names = names.split()
return [call for call in self.calls if call._name in names] return [call for call in self.calls if call._name in names]
def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None: def assert_contains(self, entries: Sequence[tuple[str, str]]) -> None:
__tracebackhide__ = True __tracebackhide__ = True
i = 0 i = 0
entries = list(entries) entries = list(entries)
@ -327,42 +323,42 @@ class HookRecorder:
@overload @overload
def getreports( def getreports(
self, self,
names: "Literal['pytest_collectreport']", names: Literal["pytest_collectreport"],
) -> Sequence[CollectReport]: ... ) -> Sequence[CollectReport]: ...
@overload @overload
def getreports( def getreports(
self, self,
names: "Literal['pytest_runtest_logreport']", names: Literal["pytest_runtest_logreport"],
) -> Sequence[TestReport]: ... ) -> Sequence[TestReport]: ...
@overload @overload
def getreports( def getreports(
self, self,
names: Union[str, Iterable[str]] = ( names: str | Iterable[str] = (
"pytest_collectreport", "pytest_collectreport",
"pytest_runtest_logreport", "pytest_runtest_logreport",
), ),
) -> Sequence[Union[CollectReport, TestReport]]: ... ) -> Sequence[CollectReport | TestReport]: ...
def getreports( def getreports(
self, self,
names: Union[str, Iterable[str]] = ( names: str | Iterable[str] = (
"pytest_collectreport", "pytest_collectreport",
"pytest_runtest_logreport", "pytest_runtest_logreport",
), ),
) -> Sequence[Union[CollectReport, TestReport]]: ) -> Sequence[CollectReport | TestReport]:
return [x.report for x in self.getcalls(names)] return [x.report for x in self.getcalls(names)]
def matchreport( def matchreport(
self, self,
inamepart: str = "", inamepart: str = "",
names: Union[str, Iterable[str]] = ( names: str | Iterable[str] = (
"pytest_runtest_logreport", "pytest_runtest_logreport",
"pytest_collectreport", "pytest_collectreport",
), ),
when: Optional[str] = None, when: str | None = None,
) -> Union[CollectReport, TestReport]: ) -> CollectReport | TestReport:
"""Return a testreport whose dotted import path matches.""" """Return a testreport whose dotted import path matches."""
values = [] values = []
for rep in self.getreports(names=names): for rep in self.getreports(names=names):
@ -387,31 +383,31 @@ class HookRecorder:
@overload @overload
def getfailures( def getfailures(
self, self,
names: "Literal['pytest_collectreport']", names: Literal["pytest_collectreport"],
) -> Sequence[CollectReport]: ... ) -> Sequence[CollectReport]: ...
@overload @overload
def getfailures( def getfailures(
self, self,
names: "Literal['pytest_runtest_logreport']", names: Literal["pytest_runtest_logreport"],
) -> Sequence[TestReport]: ... ) -> Sequence[TestReport]: ...
@overload @overload
def getfailures( def getfailures(
self, self,
names: Union[str, Iterable[str]] = ( names: str | Iterable[str] = (
"pytest_collectreport", "pytest_collectreport",
"pytest_runtest_logreport", "pytest_runtest_logreport",
), ),
) -> Sequence[Union[CollectReport, TestReport]]: ... ) -> Sequence[CollectReport | TestReport]: ...
def getfailures( def getfailures(
self, self,
names: Union[str, Iterable[str]] = ( names: str | Iterable[str] = (
"pytest_collectreport", "pytest_collectreport",
"pytest_runtest_logreport", "pytest_runtest_logreport",
), ),
) -> Sequence[Union[CollectReport, TestReport]]: ) -> Sequence[CollectReport | TestReport]:
return [rep for rep in self.getreports(names) if rep.failed] return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self) -> Sequence[CollectReport]: def getfailedcollections(self) -> Sequence[CollectReport]:
@ -419,10 +415,10 @@ class HookRecorder:
def listoutcomes( def listoutcomes(
self, self,
) -> Tuple[ ) -> tuple[
Sequence[TestReport], Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]], Sequence[CollectReport | TestReport],
Sequence[Union[CollectReport, TestReport]], Sequence[CollectReport | TestReport],
]: ]:
passed = [] passed = []
skipped = [] skipped = []
@ -441,7 +437,7 @@ class HookRecorder:
failed.append(rep) failed.append(rep)
return passed, skipped, failed return passed, skipped, failed
def countoutcomes(self) -> List[int]: def countoutcomes(self) -> list[int]:
return [len(x) for x in self.listoutcomes()] return [len(x) for x in self.listoutcomes()]
def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None: def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
@ -461,14 +457,14 @@ class HookRecorder:
@fixture @fixture
def linecomp() -> "LineComp": def linecomp() -> LineComp:
"""A :class: `LineComp` instance for checking that an input linearly """A :class: `LineComp` instance for checking that an input linearly
contains a sequence of strings.""" contains a sequence of strings."""
return LineComp() return LineComp()
@fixture(name="LineMatcher") @fixture(name="LineMatcher")
def LineMatcher_fixture(request: FixtureRequest) -> Type["LineMatcher"]: def LineMatcher_fixture(request: FixtureRequest) -> type[LineMatcher]:
"""A reference to the :class: `LineMatcher`. """A reference to the :class: `LineMatcher`.
This is instantiable with a list of lines (without their trailing newlines). This is instantiable with a list of lines (without their trailing newlines).
@ -480,7 +476,7 @@ def LineMatcher_fixture(request: FixtureRequest) -> Type["LineMatcher"]:
@fixture @fixture
def pytester( def pytester(
request: FixtureRequest, tmp_path_factory: TempPathFactory, monkeypatch: MonkeyPatch request: FixtureRequest, tmp_path_factory: TempPathFactory, monkeypatch: MonkeyPatch
) -> "Pytester": ) -> Pytester:
""" """
Facilities to write tests/configuration files, execute pytest in isolation, and match Facilities to write tests/configuration files, execute pytest in isolation, and match
against expected output, perfect for black-box testing of pytest plugins. against expected output, perfect for black-box testing of pytest plugins.
@ -524,13 +520,13 @@ class RunResult:
def __init__( def __init__(
self, self,
ret: Union[int, ExitCode], ret: int | ExitCode,
outlines: List[str], outlines: list[str],
errlines: List[str], errlines: list[str],
duration: float, duration: float,
) -> None: ) -> None:
try: try:
self.ret: Union[int, ExitCode] = ExitCode(ret) self.ret: int | ExitCode = ExitCode(ret)
"""The return value.""" """The return value."""
except ValueError: except ValueError:
self.ret = ret self.ret = ret
@ -555,7 +551,7 @@ class RunResult:
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration) % (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
) )
def parseoutcomes(self) -> Dict[str, int]: def parseoutcomes(self) -> dict[str, int]:
"""Return a dictionary of outcome noun -> count from parsing the terminal """Return a dictionary of outcome noun -> count from parsing the terminal
output that the test process produced. output that the test process produced.
@ -568,7 +564,7 @@ class RunResult:
return self.parse_summary_nouns(self.outlines) return self.parse_summary_nouns(self.outlines)
@classmethod @classmethod
def parse_summary_nouns(cls, lines) -> Dict[str, int]: def parse_summary_nouns(cls, lines) -> dict[str, int]:
"""Extract the nouns from a pytest terminal summary line. """Extract the nouns from a pytest terminal summary line.
It always returns the plural noun for consistency:: It always returns the plural noun for consistency::
@ -599,8 +595,8 @@ class RunResult:
errors: int = 0, errors: int = 0,
xpassed: int = 0, xpassed: int = 0,
xfailed: int = 0, xfailed: int = 0,
warnings: Optional[int] = None, warnings: int | None = None,
deselected: Optional[int] = None, deselected: int | None = None,
) -> None: ) -> None:
""" """
Assert that the specified outcomes appear with the respective Assert that the specified outcomes appear with the respective
@ -626,7 +622,7 @@ class RunResult:
class SysModulesSnapshot: class SysModulesSnapshot:
def __init__(self, preserve: Optional[Callable[[str], bool]] = None) -> None: def __init__(self, preserve: Callable[[str], bool] | None = None) -> None:
self.__preserve = preserve self.__preserve = preserve
self.__saved = dict(sys.modules) self.__saved = dict(sys.modules)
@ -659,7 +655,7 @@ class Pytester:
__test__ = False __test__ = False
CLOSE_STDIN: "Final" = NOTSET CLOSE_STDIN: Final = NOTSET
class TimeoutExpired(Exception): class TimeoutExpired(Exception):
pass pass
@ -674,9 +670,9 @@ class Pytester:
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self._request = request self._request = request
self._mod_collections: WeakKeyDictionary[ self._mod_collections: WeakKeyDictionary[Collector, list[Item | Collector]] = (
Collector, List[Union[Item, Collector]] WeakKeyDictionary()
] = WeakKeyDictionary() )
if request.function: if request.function:
name: str = request.function.__name__ name: str = request.function.__name__
else: else:
@ -687,7 +683,7 @@ class Pytester:
#: :py:meth:`runpytest`. Initially this is an empty list but plugins can #: :py:meth:`runpytest`. Initially this is an empty list but plugins can
#: be added to the list. The type of items to add to the list depends on #: be added to the list. The type of items to add to the list depends on
#: the method using them so refer to them for details. #: the method using them so refer to them for details.
self.plugins: List[Union[str, _PluggyPlugin]] = [] self.plugins: list[str | _PluggyPlugin] = []
self._sys_path_snapshot = SysPathsSnapshot() self._sys_path_snapshot = SysPathsSnapshot()
self._sys_modules_snapshot = self.__take_sys_modules_snapshot() self._sys_modules_snapshot = self.__take_sys_modules_snapshot()
self._request.addfinalizer(self._finalize) self._request.addfinalizer(self._finalize)
@ -755,8 +751,8 @@ class Pytester:
def _makefile( def _makefile(
self, self,
ext: str, ext: str,
lines: Sequence[Union[Any, bytes]], lines: Sequence[Any | bytes],
files: Dict[str, str], files: dict[str, str],
encoding: str = "utf-8", encoding: str = "utf-8",
) -> Path: ) -> Path:
items = list(files.items()) items = list(files.items())
@ -769,7 +765,7 @@ class Pytester:
f"pytester.makefile expects a file extension, try .{ext} instead of {ext}" f"pytester.makefile expects a file extension, try .{ext} instead of {ext}"
) )
def to_text(s: Union[Any, bytes]) -> str: def to_text(s: Any | bytes) -> str:
return s.decode(encoding) if isinstance(s, bytes) else str(s) return s.decode(encoding) if isinstance(s, bytes) else str(s)
if lines: if lines:
@ -889,9 +885,7 @@ class Pytester:
""" """
return self._makefile(".txt", args, kwargs) return self._makefile(".txt", args, kwargs)
def syspathinsert( def syspathinsert(self, path: str | os.PathLike[str] | None = None) -> None:
self, path: Optional[Union[str, "os.PathLike[str]"]] = None
) -> None:
"""Prepend a directory to sys.path, defaults to :attr:`path`. """Prepend a directory to sys.path, defaults to :attr:`path`.
This is undone automatically when this object dies at the end of each This is undone automatically when this object dies at the end of each
@ -905,19 +899,20 @@ class Pytester:
self._monkeypatch.syspath_prepend(str(path)) self._monkeypatch.syspath_prepend(str(path))
def mkdir(self, name: Union[str, "os.PathLike[str]"]) -> Path: def mkdir(self, name: str | os.PathLike[str]) -> Path:
"""Create a new (sub)directory. """Create a new (sub)directory.
:param name: :param name:
The name of the directory, relative to the pytester path. The name of the directory, relative to the pytester path.
:returns: :returns:
The created directory. The created directory.
:rtype: pathlib.Path
""" """
p = self.path / name p = self.path / name
p.mkdir() p.mkdir()
return p return p
def mkpydir(self, name: Union[str, "os.PathLike[str]"]) -> Path: def mkpydir(self, name: str | os.PathLike[str]) -> Path:
"""Create a new python package. """Create a new python package.
This creates a (sub)directory with an empty ``__init__.py`` file so it This creates a (sub)directory with an empty ``__init__.py`` file so it
@ -928,13 +923,14 @@ class Pytester:
p.joinpath("__init__.py").touch() p.joinpath("__init__.py").touch()
return p return p
def copy_example(self, name: Optional[str] = None) -> Path: def copy_example(self, name: str | None = None) -> Path:
"""Copy file from project's directory into the testdir. """Copy file from project's directory into the testdir.
:param name: :param name:
The name of the file to copy. The name of the file to copy.
:return: :return:
Path to the copied directory (inside ``self.path``). Path to the copied directory (inside ``self.path``).
:rtype: pathlib.Path
""" """
example_dir_ = self._request.config.getini("pytester_example_dir") example_dir_ = self._request.config.getini("pytester_example_dir")
if example_dir_ is None: if example_dir_ is None:
@ -973,9 +969,7 @@ class Pytester:
f'example "{example_path}" is not found as a file or directory' f'example "{example_path}" is not found as a file or directory'
) )
def getnode( def getnode(self, config: Config, arg: str | os.PathLike[str]) -> Collector | Item:
self, config: Config, arg: Union[str, "os.PathLike[str]"]
) -> Union[Collector, Item]:
"""Get the collection node of a file. """Get the collection node of a file.
:param config: :param config:
@ -994,9 +988,7 @@ class Pytester:
config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK) config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK)
return res return res
def getpathnode( def getpathnode(self, path: str | os.PathLike[str]) -> Collector | Item:
self, path: Union[str, "os.PathLike[str]"]
) -> Union[Collector, Item]:
"""Return the collection node of a file. """Return the collection node of a file.
This is like :py:meth:`getnode` but uses :py:meth:`parseconfigure` to This is like :py:meth:`getnode` but uses :py:meth:`parseconfigure` to
@ -1016,7 +1008,7 @@ class Pytester:
config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK) config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK)
return res return res
def genitems(self, colitems: Sequence[Union[Item, Collector]]) -> List[Item]: def genitems(self, colitems: Sequence[Item | Collector]) -> list[Item]:
"""Generate all test items from a collection node. """Generate all test items from a collection node.
This recurses into the collection node and returns a list of all the This recurses into the collection node and returns a list of all the
@ -1028,7 +1020,7 @@ class Pytester:
The collected items. The collected items.
""" """
session = colitems[0].session session = colitems[0].session
result: List[Item] = [] result: list[Item] = []
for colitem in colitems: for colitem in colitems:
result.extend(session.genitems(colitem)) result.extend(session.genitems(colitem))
return result return result
@ -1062,7 +1054,7 @@ class Pytester:
values = [*list(cmdlineargs), p] values = [*list(cmdlineargs), p]
return self.inline_run(*values) return self.inline_run(*values)
def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]: def inline_genitems(self, *args) -> tuple[list[Item], HookRecorder]:
"""Run ``pytest.main(['--collect-only'])`` in-process. """Run ``pytest.main(['--collect-only'])`` in-process.
Runs the :py:func:`pytest.main` function to run all of pytest inside Runs the :py:func:`pytest.main` function to run all of pytest inside
@ -1075,7 +1067,7 @@ class Pytester:
def inline_run( def inline_run(
self, self,
*args: Union[str, "os.PathLike[str]"], *args: str | os.PathLike[str],
plugins=(), plugins=(),
no_reraise_ctrlc: bool = False, no_reraise_ctrlc: bool = False,
) -> HookRecorder: ) -> HookRecorder:
@ -1145,7 +1137,7 @@ class Pytester:
finalizer() finalizer()
def runpytest_inprocess( def runpytest_inprocess(
self, *args: Union[str, "os.PathLike[str]"], **kwargs: Any self, *args: str | os.PathLike[str], **kwargs: Any
) -> RunResult: ) -> RunResult:
"""Return result of running pytest in-process, providing a similar """Return result of running pytest in-process, providing a similar
interface to what self.runpytest() provides.""" interface to what self.runpytest() provides."""
@ -1188,9 +1180,7 @@ class Pytester:
res.reprec = reprec # type: ignore res.reprec = reprec # type: ignore
return res return res
def runpytest( def runpytest(self, *args: str | os.PathLike[str], **kwargs: Any) -> RunResult:
self, *args: Union[str, "os.PathLike[str]"], **kwargs: Any
) -> RunResult:
"""Run pytest inline or in a subprocess, depending on the command line """Run pytest inline or in a subprocess, depending on the command line
option "--runpytest" and return a :py:class:`~pytest.RunResult`.""" option "--runpytest" and return a :py:class:`~pytest.RunResult`."""
new_args = self._ensure_basetemp(args) new_args = self._ensure_basetemp(args)
@ -1201,8 +1191,8 @@ class Pytester:
raise RuntimeError(f"Unrecognized runpytest option: {self._method}") raise RuntimeError(f"Unrecognized runpytest option: {self._method}")
def _ensure_basetemp( def _ensure_basetemp(
self, args: Sequence[Union[str, "os.PathLike[str]"]] self, args: Sequence[str | os.PathLike[str]]
) -> List[Union[str, "os.PathLike[str]"]]: ) -> list[str | os.PathLike[str]]:
new_args = list(args) new_args = list(args)
for x in new_args: for x in new_args:
if str(x).startswith("--basetemp"): if str(x).startswith("--basetemp"):
@ -1211,7 +1201,7 @@ class Pytester:
new_args.append("--basetemp=%s" % self.path.parent.joinpath("basetemp")) new_args.append("--basetemp=%s" % self.path.parent.joinpath("basetemp"))
return new_args return new_args
def parseconfig(self, *args: Union[str, "os.PathLike[str]"]) -> Config: def parseconfig(self, *args: str | os.PathLike[str]) -> Config:
"""Return a new pytest :class:`pytest.Config` instance from given """Return a new pytest :class:`pytest.Config` instance from given
commandline args. commandline args.
@ -1235,7 +1225,7 @@ class Pytester:
self._request.addfinalizer(config._ensure_unconfigure) self._request.addfinalizer(config._ensure_unconfigure)
return config return config
def parseconfigure(self, *args: Union[str, "os.PathLike[str]"]) -> Config: def parseconfigure(self, *args: str | os.PathLike[str]) -> Config:
"""Return a new pytest configured Config instance. """Return a new pytest configured Config instance.
Returns a new :py:class:`pytest.Config` instance like Returns a new :py:class:`pytest.Config` instance like
@ -1247,7 +1237,7 @@ class Pytester:
return config return config
def getitem( def getitem(
self, source: Union[str, "os.PathLike[str]"], funcname: str = "test_func" self, source: str | os.PathLike[str], funcname: str = "test_func"
) -> Item: ) -> Item:
"""Return the test item for a test function. """Return the test item for a test function.
@ -1268,7 +1258,7 @@ class Pytester:
return item return item
assert 0, f"{funcname!r} item not found in module:\n{source}\nitems: {items}" assert 0, f"{funcname!r} item not found in module:\n{source}\nitems: {items}"
def getitems(self, source: Union[str, "os.PathLike[str]"]) -> List[Item]: def getitems(self, source: str | os.PathLike[str]) -> list[Item]:
"""Return all test items collected from the module. """Return all test items collected from the module.
Writes the source to a Python file and runs pytest's collection on Writes the source to a Python file and runs pytest's collection on
@ -1279,7 +1269,7 @@ class Pytester:
def getmodulecol( def getmodulecol(
self, self,
source: Union[str, "os.PathLike[str]"], source: str | os.PathLike[str],
configargs=(), configargs=(),
*, *,
withinit: bool = False, withinit: bool = False,
@ -1311,9 +1301,7 @@ class Pytester:
self.config = config = self.parseconfigure(path, *configargs) self.config = config = self.parseconfigure(path, *configargs)
return self.getnode(config, path) return self.getnode(config, path)
def collect_by_name( def collect_by_name(self, modcol: Collector, name: str) -> Item | Collector | None:
self, modcol: Collector, name: str
) -> Optional[Union[Item, Collector]]:
"""Return the collection node for name from the module collection. """Return the collection node for name from the module collection.
Searches a module collection node for a collection node matching the Searches a module collection node for a collection node matching the
@ -1331,10 +1319,10 @@ class Pytester:
def popen( def popen(
self, self,
cmdargs: Sequence[Union[str, "os.PathLike[str]"]], cmdargs: Sequence[str | os.PathLike[str]],
stdout: Union[int, TextIO] = subprocess.PIPE, stdout: int | TextIO = subprocess.PIPE,
stderr: Union[int, TextIO] = subprocess.PIPE, stderr: int | TextIO = subprocess.PIPE,
stdin: Union[NotSetType, bytes, IO[Any], int] = CLOSE_STDIN, stdin: NotSetType | bytes | IO[Any] | int = CLOSE_STDIN,
**kw, **kw,
): ):
"""Invoke :py:class:`subprocess.Popen`. """Invoke :py:class:`subprocess.Popen`.
@ -1369,9 +1357,9 @@ class Pytester:
def run( def run(
self, self,
*cmdargs: Union[str, "os.PathLike[str]"], *cmdargs: str | os.PathLike[str],
timeout: Optional[float] = None, timeout: float | None = None,
stdin: Union[NotSetType, bytes, IO[Any], int] = CLOSE_STDIN, stdin: NotSetType | bytes | IO[Any] | int = CLOSE_STDIN,
) -> RunResult: ) -> RunResult:
"""Run a command with arguments. """Run a command with arguments.
@ -1399,8 +1387,10 @@ class Pytester:
- Otherwise, it is passed through to :py:class:`subprocess.Popen`. - Otherwise, it is passed through to :py:class:`subprocess.Popen`.
For further information in this case, consult the document of the For further information in this case, consult the document of the
``stdin`` parameter in :py:class:`subprocess.Popen`. ``stdin`` parameter in :py:class:`subprocess.Popen`.
:type stdin: _pytest.compat.NotSetType | bytes | IO[Any] | int
:returns: :returns:
The result. The result.
""" """
__tracebackhide__ = True __tracebackhide__ = True
@ -1457,10 +1447,10 @@ class Pytester:
except UnicodeEncodeError: except UnicodeEncodeError:
print(f"couldn't print to {fp} because of encoding") print(f"couldn't print to {fp} because of encoding")
def _getpytestargs(self) -> Tuple[str, ...]: def _getpytestargs(self) -> tuple[str, ...]:
return sys.executable, "-mpytest" return sys.executable, "-mpytest"
def runpython(self, script: "os.PathLike[str]") -> RunResult: def runpython(self, script: os.PathLike[str]) -> RunResult:
"""Run a python script using sys.executable as interpreter.""" """Run a python script using sys.executable as interpreter."""
return self.run(sys.executable, script) return self.run(sys.executable, script)
@ -1469,7 +1459,7 @@ class Pytester:
return self.run(sys.executable, "-c", command) return self.run(sys.executable, "-c", command)
def runpytest_subprocess( def runpytest_subprocess(
self, *args: Union[str, "os.PathLike[str]"], timeout: Optional[float] = None self, *args: str | os.PathLike[str], timeout: float | None = None
) -> RunResult: ) -> RunResult:
"""Run pytest as a subprocess with given arguments. """Run pytest as a subprocess with given arguments.
@ -1496,9 +1486,7 @@ class Pytester:
args = self._getpytestargs() + args args = self._getpytestargs() + args
return self.run(*args, timeout=timeout) return self.run(*args, timeout=timeout)
def spawn_pytest( def spawn_pytest(self, string: str, expect_timeout: float = 10.0) -> pexpect.spawn:
self, string: str, expect_timeout: float = 10.0
) -> "pexpect.spawn":
"""Run pytest using pexpect. """Run pytest using pexpect.
This makes sure to use the right pytest and sets up the temporary This makes sure to use the right pytest and sets up the temporary
@ -1512,7 +1500,7 @@ class Pytester:
cmd = f"{invoke} --basetemp={basetemp} {string}" cmd = f"{invoke} --basetemp={basetemp} {string}"
return self.spawn(cmd, expect_timeout=expect_timeout) return self.spawn(cmd, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn": def spawn(self, cmd: str, expect_timeout: float = 10.0) -> pexpect.spawn:
"""Run a command using pexpect. """Run a command using pexpect.
The pexpect child is returned. The pexpect child is returned.
@ -1557,9 +1545,9 @@ class LineMatcher:
``text.splitlines()``. ``text.splitlines()``.
""" """
def __init__(self, lines: List[str]) -> None: def __init__(self, lines: list[str]) -> None:
self.lines = lines self.lines = lines
self._log_output: List[str] = [] self._log_output: list[str] = []
def __str__(self) -> str: def __str__(self) -> str:
"""Return the entire original text. """Return the entire original text.
@ -1569,7 +1557,7 @@ class LineMatcher:
""" """
return "\n".join(self.lines) return "\n".join(self.lines)
def _getlines(self, lines2: Union[str, Sequence[str], Source]) -> Sequence[str]: def _getlines(self, lines2: str | Sequence[str] | Source) -> Sequence[str]:
if isinstance(lines2, str): if isinstance(lines2, str):
lines2 = Source(lines2) lines2 = Source(lines2)
if isinstance(lines2, Source): if isinstance(lines2, Source):

View File

@ -4,21 +4,19 @@
# contain them itself, since it is imported by the `pytest` module, # contain them itself, since it is imported by the `pytest` module,
# hence cannot be subject to assertion rewriting, which requires a # hence cannot be subject to assertion rewriting, which requires a
# module to not be already imported. # module to not be already imported.
from typing import Dict from __future__ import annotations
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Tuple
from typing import Union
from _pytest.reports import CollectReport from _pytest.reports import CollectReport
from _pytest.reports import TestReport from _pytest.reports import TestReport
def assertoutcome( def assertoutcome(
outcomes: Tuple[ outcomes: tuple[
Sequence[TestReport], Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]], Sequence[CollectReport | TestReport],
Sequence[Union[CollectReport, TestReport]], Sequence[CollectReport | TestReport],
], ],
passed: int = 0, passed: int = 0,
skipped: int = 0, skipped: int = 0,
@ -37,15 +35,15 @@ def assertoutcome(
def assert_outcomes( def assert_outcomes(
outcomes: Dict[str, int], outcomes: dict[str, int],
passed: int = 0, passed: int = 0,
skipped: int = 0, skipped: int = 0,
failed: int = 0, failed: int = 0,
errors: int = 0, errors: int = 0,
xpassed: int = 0, xpassed: int = 0,
xfailed: int = 0, xfailed: int = 0,
warnings: Optional[int] = None, warnings: int | None = None,
deselected: Optional[int] = None, deselected: int | None = None,
) -> None: ) -> None:
"""Assert that the specified outcomes appear with the respective """Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run.""" numbers (0 means it didn't occur) in the text output from a test run."""

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Python test discovery, setup and run of test functions.""" """Python test discovery, setup and run of test functions."""
from __future__ import annotations
import abc import abc
from collections import Counter from collections import Counter
from collections import defaultdict from collections import defaultdict
@ -20,16 +22,11 @@ from typing import final
from typing import Generator from typing import Generator
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Sequence from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import warnings import warnings
import _pytest import _pytest
@ -113,7 +110,7 @@ def pytest_addoption(parser: Parser) -> None:
) )
def pytest_generate_tests(metafunc: "Metafunc") -> None: def pytest_generate_tests(metafunc: Metafunc) -> None:
for marker in metafunc.definition.iter_markers(name="parametrize"): for marker in metafunc.definition.iter_markers(name="parametrize"):
metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker) metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker)
@ -153,7 +150,7 @@ def async_warn_and_skip(nodeid: str) -> None:
@hookimpl(trylast=True) @hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]: def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
testfunction = pyfuncitem.obj testfunction = pyfuncitem.obj
if is_async_function(testfunction): if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid) async_warn_and_skip(pyfuncitem.nodeid)
@ -174,7 +171,7 @@ def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
def pytest_collect_directory( def pytest_collect_directory(
path: Path, parent: nodes.Collector path: Path, parent: nodes.Collector
) -> Optional[nodes.Collector]: ) -> nodes.Collector | None:
pkginit = path / "__init__.py" pkginit = path / "__init__.py"
try: try:
has_pkginit = pkginit.is_file() has_pkginit = pkginit.is_file()
@ -186,7 +183,7 @@ def pytest_collect_directory(
return None return None
def pytest_collect_file(file_path: Path, parent: nodes.Collector) -> Optional["Module"]: def pytest_collect_file(file_path: Path, parent: nodes.Collector) -> Module | None:
if file_path.suffix == ".py": if file_path.suffix == ".py":
if not parent.session.isinitpath(file_path): if not parent.session.isinitpath(file_path):
if not path_matches_patterns( if not path_matches_patterns(
@ -206,14 +203,14 @@ def path_matches_patterns(path: Path, patterns: Iterable[str]) -> bool:
return any(fnmatch_ex(pattern, path) for pattern in patterns) return any(fnmatch_ex(pattern, path) for pattern in patterns)
def pytest_pycollect_makemodule(module_path: Path, parent) -> "Module": def pytest_pycollect_makemodule(module_path: Path, parent) -> Module:
return Module.from_parent(parent, path=module_path) return Module.from_parent(parent, path=module_path)
@hookimpl(trylast=True) @hookimpl(trylast=True)
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: Union["Module", "Class"], name: str, obj: object collector: Module | Class, name: str, obj: object
) -> Union[None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]]: ) -> None | nodes.Item | nodes.Collector | list[nodes.Item | nodes.Collector]:
assert isinstance(collector, (Class, Module)), type(collector) assert isinstance(collector, (Class, Module)), type(collector)
# Nothing was collected elsewhere, let's do it here. # Nothing was collected elsewhere, let's do it here.
if safe_isclass(obj): if safe_isclass(obj):
@ -320,7 +317,7 @@ class PyobjMixin(nodes.Node):
parts.reverse() parts.reverse()
return ".".join(parts) return ".".join(parts)
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]: def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
# XXX caching? # XXX caching?
path, lineno = getfslineno(self.obj) path, lineno = getfslineno(self.obj)
modpath = self.getmodpath() modpath = self.getmodpath()
@ -390,7 +387,7 @@ class PyCollector(PyobjMixin, nodes.Collector, abc.ABC):
return True return True
return False return False
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
if not getattr(self.obj, "__test__", True): if not getattr(self.obj, "__test__", True):
return [] return []
@ -402,11 +399,11 @@ class PyCollector(PyobjMixin, nodes.Collector, abc.ABC):
# In each class, nodes should be definition ordered. # In each class, nodes should be definition ordered.
# __dict__ is definition ordered. # __dict__ is definition ordered.
seen: Set[str] = set() seen: set[str] = set()
dict_values: List[List[Union[nodes.Item, nodes.Collector]]] = [] dict_values: list[list[nodes.Item | nodes.Collector]] = []
ihook = self.ihook ihook = self.ihook
for dic in dicts: for dic in dicts:
values: List[Union[nodes.Item, nodes.Collector]] = [] values: list[nodes.Item | nodes.Collector] = []
# Note: seems like the dict can change during iteration - # Note: seems like the dict can change during iteration -
# be careful not to remove the list() without consideration. # be careful not to remove the list() without consideration.
for name, obj in list(dic.items()): for name, obj in list(dic.items()):
@ -433,7 +430,7 @@ class PyCollector(PyobjMixin, nodes.Collector, abc.ABC):
result.extend(values) result.extend(values)
return result return result
def _genfunctions(self, name: str, funcobj) -> Iterator["Function"]: def _genfunctions(self, name: str, funcobj) -> Iterator[Function]:
modulecol = self.getparent(Module) modulecol = self.getparent(Module)
assert modulecol is not None assert modulecol is not None
module = modulecol.obj module = modulecol.obj
@ -544,7 +541,7 @@ class Module(nodes.File, PyCollector):
def _getobj(self): def _getobj(self):
return importtestmodule(self.path, self.config) return importtestmodule(self.path, self.config)
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
self._register_setup_module_fixture() self._register_setup_module_fixture()
self._register_setup_function_fixture() self._register_setup_function_fixture()
self.session._fixturemanager.parsefactories(self) self.session._fixturemanager.parsefactories(self)
@ -638,13 +635,13 @@ class Package(nodes.Directory):
def __init__( def __init__(
self, self,
fspath: Optional[LEGACY_PATH], fspath: LEGACY_PATH | None,
parent: nodes.Collector, parent: nodes.Collector,
# NOTE: following args are unused: # NOTE: following args are unused:
config=None, config=None,
session=None, session=None,
nodeid=None, nodeid=None,
path: Optional[Path] = None, path: Path | None = None,
) -> None: ) -> None:
# NOTE: Could be just the following, but kept as-is for compat. # NOTE: Could be just the following, but kept as-is for compat.
# super().__init__(self, fspath, parent=parent) # super().__init__(self, fspath, parent=parent)
@ -676,13 +673,13 @@ class Package(nodes.Directory):
func = partial(_call_with_optional_argument, teardown_module, init_mod) func = partial(_call_with_optional_argument, teardown_module, init_mod)
self.addfinalizer(func) self.addfinalizer(func)
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
# Always collect __init__.py first. # Always collect __init__.py first.
def sort_key(entry: "os.DirEntry[str]") -> object: def sort_key(entry: os.DirEntry[str]) -> object:
return (entry.name != "__init__.py", entry.name) return (entry.name != "__init__.py", entry.name)
config = self.config config = self.config
col: Optional[nodes.Collector] col: nodes.Collector | None
cols: Sequence[nodes.Collector] cols: Sequence[nodes.Collector]
ihook = self.ihook ihook = self.ihook
for direntry in scandir(self.path, sort_key): for direntry in scandir(self.path, sort_key):
@ -716,12 +713,12 @@ def _call_with_optional_argument(func, arg) -> None:
func() func()
def _get_first_non_fixture_func(obj: object, names: Iterable[str]) -> Optional[object]: def _get_first_non_fixture_func(obj: object, names: Iterable[str]) -> object | None:
"""Return the attribute from the given object to be used as a setup/teardown """Return the attribute from the given object to be used as a setup/teardown
xunit-style function, but only if not marked as a fixture to avoid calling it twice. xunit-style function, but only if not marked as a fixture to avoid calling it twice.
""" """
for name in names: for name in names:
meth: Optional[object] = getattr(obj, name, None) meth: object | None = getattr(obj, name, None)
if meth is not None and fixtures.getfixturemarker(meth) is None: if meth is not None and fixtures.getfixturemarker(meth) is None:
return meth return meth
return None return None
@ -731,14 +728,14 @@ class Class(PyCollector):
"""Collector for test methods (and nested classes) in a Python class.""" """Collector for test methods (and nested classes) in a Python class."""
@classmethod @classmethod
def from_parent(cls, parent, *, name, obj=None, **kw) -> "Self": # type: ignore[override] def from_parent(cls, parent, *, name, obj=None, **kw) -> Self: # type: ignore[override]
"""The public constructor.""" """The public constructor."""
return super().from_parent(name=name, parent=parent, **kw) return super().from_parent(name=name, parent=parent, **kw)
def newinstance(self): def newinstance(self):
return self.obj() return self.obj()
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
if not safe_getattr(self.obj, "__test__", True): if not safe_getattr(self.obj, "__test__", True):
return [] return []
if hasinit(self.obj): if hasinit(self.obj):
@ -868,21 +865,21 @@ class IdMaker:
parametersets: Sequence[ParameterSet] parametersets: Sequence[ParameterSet]
# Optionally, a user-provided callable to make IDs for parameters in a # Optionally, a user-provided callable to make IDs for parameters in a
# ParameterSet. # ParameterSet.
idfn: Optional[Callable[[Any], Optional[object]]] idfn: Callable[[Any], object | None] | None
# Optionally, explicit IDs for ParameterSets by index. # Optionally, explicit IDs for ParameterSets by index.
ids: Optional[Sequence[Optional[object]]] ids: Sequence[object | None] | None
# Optionally, the pytest config. # Optionally, the pytest config.
# Used for controlling ASCII escaping, and for calling the # Used for controlling ASCII escaping, and for calling the
# :hook:`pytest_make_parametrize_id` hook. # :hook:`pytest_make_parametrize_id` hook.
config: Optional[Config] config: Config | None
# Optionally, the ID of the node being parametrized. # Optionally, the ID of the node being parametrized.
# Used only for clearer error messages. # Used only for clearer error messages.
nodeid: Optional[str] nodeid: str | None
# Optionally, the ID of the function being parametrized. # Optionally, the ID of the function being parametrized.
# Used only for clearer error messages. # Used only for clearer error messages.
func_name: Optional[str] func_name: str | None
def make_unique_parameterset_ids(self) -> List[str]: def make_unique_parameterset_ids(self) -> list[str]:
"""Make a unique identifier for each ParameterSet, that may be used to """Make a unique identifier for each ParameterSet, that may be used to
identify the parametrization in a node ID. identify the parametrization in a node ID.
@ -899,7 +896,7 @@ class IdMaker:
# Record the number of occurrences of each ID. # Record the number of occurrences of each ID.
id_counts = Counter(resolved_ids) id_counts = Counter(resolved_ids)
# Map the ID to its next suffix. # Map the ID to its next suffix.
id_suffixes: Dict[str, int] = defaultdict(int) id_suffixes: dict[str, int] = defaultdict(int)
# Suffix non-unique IDs to make them unique. # Suffix non-unique IDs to make them unique.
for index, id in enumerate(resolved_ids): for index, id in enumerate(resolved_ids):
if id_counts[id] > 1: if id_counts[id] > 1:
@ -946,9 +943,7 @@ class IdMaker:
return idval return idval
return self._idval_from_argname(argname, idx) return self._idval_from_argname(argname, idx)
def _idval_from_function( def _idval_from_function(self, val: object, argname: str, idx: int) -> str | None:
self, val: object, argname: str, idx: int
) -> Optional[str]:
"""Try to make an ID for a parameter in a ParameterSet using the """Try to make an ID for a parameter in a ParameterSet using the
user-provided id callable, if given.""" user-provided id callable, if given."""
if self.idfn is None: if self.idfn is None:
@ -964,17 +959,17 @@ class IdMaker:
return None return None
return self._idval_from_value(id) return self._idval_from_value(id)
def _idval_from_hook(self, val: object, argname: str) -> Optional[str]: def _idval_from_hook(self, val: object, argname: str) -> str | None:
"""Try to make an ID for a parameter in a ParameterSet by calling the """Try to make an ID for a parameter in a ParameterSet by calling the
:hook:`pytest_make_parametrize_id` hook.""" :hook:`pytest_make_parametrize_id` hook."""
if self.config: if self.config:
id: Optional[str] = self.config.hook.pytest_make_parametrize_id( id: str | None = self.config.hook.pytest_make_parametrize_id(
config=self.config, val=val, argname=argname config=self.config, val=val, argname=argname
) )
return id return id
return None return None
def _idval_from_value(self, val: object) -> Optional[str]: def _idval_from_value(self, val: object) -> str | None:
"""Try to make an ID for a parameter in a ParameterSet from its value, """Try to make an ID for a parameter in a ParameterSet from its value,
if the value type is supported.""" if the value type is supported."""
if isinstance(val, (str, bytes)): if isinstance(val, (str, bytes)):
@ -1032,15 +1027,15 @@ class CallSpec2:
# arg name -> arg value which will be passed to a fixture or pseudo-fixture # arg name -> arg value which will be passed to a fixture or pseudo-fixture
# of the same name. (indirect or direct parametrization respectively) # of the same name. (indirect or direct parametrization respectively)
params: Dict[str, object] = dataclasses.field(default_factory=dict) params: dict[str, object] = dataclasses.field(default_factory=dict)
# arg name -> arg index. # arg name -> arg index.
indices: Dict[str, int] = dataclasses.field(default_factory=dict) indices: dict[str, int] = dataclasses.field(default_factory=dict)
# Used for sorting parametrized resources. # Used for sorting parametrized resources.
_arg2scope: Mapping[str, Scope] = dataclasses.field(default_factory=dict) _arg2scope: Mapping[str, Scope] = dataclasses.field(default_factory=dict)
# Parts which will be added to the item's name in `[..]` separated by "-". # Parts which will be added to the item's name in `[..]` separated by "-".
_idlist: Sequence[str] = dataclasses.field(default_factory=tuple) _idlist: Sequence[str] = dataclasses.field(default_factory=tuple)
# Marks which will be applied to the item. # Marks which will be applied to the item.
marks: List[Mark] = dataclasses.field(default_factory=list) marks: list[Mark] = dataclasses.field(default_factory=list)
def setmulti( def setmulti(
self, self,
@ -1048,10 +1043,10 @@ class CallSpec2:
argnames: Iterable[str], argnames: Iterable[str],
valset: Iterable[object], valset: Iterable[object],
id: str, id: str,
marks: Iterable[Union[Mark, MarkDecorator]], marks: Iterable[Mark | MarkDecorator],
scope: Scope, scope: Scope,
param_index: int, param_index: int,
) -> "CallSpec2": ) -> CallSpec2:
params = self.params.copy() params = self.params.copy()
indices = self.indices.copy() indices = self.indices.copy()
arg2scope = dict(self._arg2scope) arg2scope = dict(self._arg2scope)
@ -1099,7 +1094,7 @@ class Metafunc:
def __init__( def __init__(
self, self,
definition: "FunctionDefinition", definition: FunctionDefinition,
fixtureinfo: fixtures.FuncFixtureInfo, fixtureinfo: fixtures.FuncFixtureInfo,
config: Config, config: Config,
cls=None, cls=None,
@ -1130,19 +1125,17 @@ class Metafunc:
self._arg2fixturedefs = fixtureinfo.name2fixturedefs self._arg2fixturedefs = fixtureinfo.name2fixturedefs
# Result of parametrize(). # Result of parametrize().
self._calls: List[CallSpec2] = [] self._calls: list[CallSpec2] = []
def parametrize( def parametrize(
self, self,
argnames: Union[str, Sequence[str]], argnames: str | Sequence[str],
argvalues: Iterable[Union[ParameterSet, Sequence[object], object]], argvalues: Iterable[ParameterSet | Sequence[object] | object],
indirect: Union[bool, Sequence[str]] = False, indirect: bool | Sequence[str] = False,
ids: Optional[ ids: Iterable[object | None] | Callable[[Any], object | None] | None = None,
Union[Iterable[Optional[object]], Callable[[Any], Optional[object]]] scope: _ScopeName | None = None,
] = None,
scope: Optional[_ScopeName] = None,
*, *,
_param_mark: Optional[Mark] = None, _param_mark: Mark | None = None,
) -> None: ) -> None:
"""Add new invocations to the underlying test function using the list """Add new invocations to the underlying test function using the list
of argvalues for the given argnames. Parametrization is performed of argvalues for the given argnames. Parametrization is performed
@ -1171,7 +1164,7 @@ class Metafunc:
If N argnames were specified, argvalues must be a list of If N argnames were specified, argvalues must be a list of
N-tuples, where each tuple-element specifies a value for its N-tuples, where each tuple-element specifies a value for its
respective argname. respective argname.
:type argvalues: Iterable[_pytest.mark.structures.ParameterSet | Sequence[object] | object]
:param indirect: :param indirect:
A list of arguments' names (subset of argnames) or a boolean. A list of arguments' names (subset of argnames) or a boolean.
If True the list contains all names from the argnames. Each If True the list contains all names from the argnames. Each
@ -1271,7 +1264,7 @@ class Metafunc:
if node is None: if node is None:
name2pseudofixturedef = None name2pseudofixturedef = None
else: else:
default: Dict[str, FixtureDef[Any]] = {} default: dict[str, FixtureDef[Any]] = {}
name2pseudofixturedef = node.stash.setdefault( name2pseudofixturedef = node.stash.setdefault(
name2pseudofixturedef_key, default name2pseudofixturedef_key, default
) )
@ -1318,12 +1311,10 @@ class Metafunc:
def _resolve_parameter_set_ids( def _resolve_parameter_set_ids(
self, self,
argnames: Sequence[str], argnames: Sequence[str],
ids: Optional[ ids: Iterable[object | None] | Callable[[Any], object | None] | None,
Union[Iterable[Optional[object]], Callable[[Any], Optional[object]]]
],
parametersets: Sequence[ParameterSet], parametersets: Sequence[ParameterSet],
nodeid: str, nodeid: str,
) -> List[str]: ) -> list[str]:
"""Resolve the actual ids for the given parameter sets. """Resolve the actual ids for the given parameter sets.
:param argnames: :param argnames:
@ -1361,10 +1352,10 @@ class Metafunc:
def _validate_ids( def _validate_ids(
self, self,
ids: Iterable[Optional[object]], ids: Iterable[object | None],
parametersets: Sequence[ParameterSet], parametersets: Sequence[ParameterSet],
func_name: str, func_name: str,
) -> List[Optional[object]]: ) -> list[object | None]:
try: try:
num_ids = len(ids) # type: ignore[arg-type] num_ids = len(ids) # type: ignore[arg-type]
except TypeError: except TypeError:
@ -1384,8 +1375,8 @@ class Metafunc:
def _resolve_args_directness( def _resolve_args_directness(
self, self,
argnames: Sequence[str], argnames: Sequence[str],
indirect: Union[bool, Sequence[str]], indirect: bool | Sequence[str],
) -> Dict[str, Literal["indirect", "direct"]]: ) -> dict[str, Literal["indirect", "direct"]]:
"""Resolve if each parametrized argument must be considered an indirect """Resolve if each parametrized argument must be considered an indirect
parameter to a fixture of the same name, or a direct parameter to the parameter to a fixture of the same name, or a direct parameter to the
parametrized function, based on the ``indirect`` parameter of the parametrized function, based on the ``indirect`` parameter of the
@ -1398,7 +1389,7 @@ class Metafunc:
:returns :returns
A dict mapping each arg name to either "indirect" or "direct". A dict mapping each arg name to either "indirect" or "direct".
""" """
arg_directness: Dict[str, Literal["indirect", "direct"]] arg_directness: dict[str, Literal["indirect", "direct"]]
if isinstance(indirect, bool): if isinstance(indirect, bool):
arg_directness = dict.fromkeys( arg_directness = dict.fromkeys(
argnames, "indirect" if indirect else "direct" argnames, "indirect" if indirect else "direct"
@ -1423,7 +1414,7 @@ class Metafunc:
def _validate_if_using_arg_names( def _validate_if_using_arg_names(
self, self,
argnames: Sequence[str], argnames: Sequence[str],
indirect: Union[bool, Sequence[str]], indirect: bool | Sequence[str],
) -> None: ) -> None:
"""Check if all argnames are being used, by default values, or directly/indirectly. """Check if all argnames are being used, by default values, or directly/indirectly.
@ -1454,7 +1445,7 @@ class Metafunc:
def _find_parametrized_scope( def _find_parametrized_scope(
argnames: Sequence[str], argnames: Sequence[str],
arg2fixturedefs: Mapping[str, Sequence[fixtures.FixtureDef[object]]], arg2fixturedefs: Mapping[str, Sequence[fixtures.FixtureDef[object]]],
indirect: Union[bool, Sequence[str]], indirect: bool | Sequence[str],
) -> Scope: ) -> Scope:
"""Find the most appropriate scope for a parametrized call based on its arguments. """Find the most appropriate scope for a parametrized call based on its arguments.
@ -1483,7 +1474,7 @@ def _find_parametrized_scope(
return Scope.Function return Scope.Function
def _ascii_escaped_by_config(val: Union[str, bytes], config: Optional[Config]) -> str: def _ascii_escaped_by_config(val: str | bytes, config: Config | None) -> str:
if config is None: if config is None:
escape_option = False escape_option = False
else: else:
@ -1532,13 +1523,13 @@ class Function(PyobjMixin, nodes.Item):
self, self,
name: str, name: str,
parent, parent,
config: Optional[Config] = None, config: Config | None = None,
callspec: Optional[CallSpec2] = None, callspec: CallSpec2 | None = None,
callobj=NOTSET, callobj=NOTSET,
keywords: Optional[Mapping[str, Any]] = None, keywords: Mapping[str, Any] | None = None,
session: Optional[Session] = None, session: Session | None = None,
fixtureinfo: Optional[FuncFixtureInfo] = None, fixtureinfo: FuncFixtureInfo | None = None,
originalname: Optional[str] = None, originalname: str | None = None,
) -> None: ) -> None:
super().__init__(name, parent, config=config, session=session) super().__init__(name, parent, config=config, session=session)
@ -1581,12 +1572,12 @@ class Function(PyobjMixin, nodes.Item):
# todo: determine sound type limitations # todo: determine sound type limitations
@classmethod @classmethod
def from_parent(cls, parent, **kw) -> "Self": def from_parent(cls, parent, **kw) -> Self:
"""The public constructor.""" """The public constructor."""
return super().from_parent(parent=parent, **kw) return super().from_parent(parent=parent, **kw)
def _initrequest(self) -> None: def _initrequest(self) -> None:
self.funcargs: Dict[str, object] = {} self.funcargs: dict[str, object] = {}
self._request = fixtures.TopRequest(self, _ispytest=True) self._request = fixtures.TopRequest(self, _ispytest=True)
@property @property
@ -1667,7 +1658,7 @@ class Function(PyobjMixin, nodes.Item):
def repr_failure( # type: ignore[override] def repr_failure( # type: ignore[override]
self, self,
excinfo: ExceptionInfo[BaseException], excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]: ) -> str | TerminalRepr:
style = self.config.getoption("tbstyle", "auto") style = self.config.getoption("tbstyle", "auto")
if style == "auto": if style == "auto":
style = "long" style = "long"

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
from collections.abc import Collection from collections.abc import Collection
from collections.abc import Sized from collections.abc import Sized
from decimal import Decimal from decimal import Decimal
@ -11,9 +13,7 @@ from typing import Callable
from typing import cast from typing import cast
from typing import ContextManager from typing import ContextManager
from typing import final from typing import final
from typing import List
from typing import Mapping from typing import Mapping
from typing import Optional
from typing import overload from typing import overload
from typing import Pattern from typing import Pattern
from typing import Sequence from typing import Sequence
@ -21,7 +21,6 @@ from typing import Tuple
from typing import Type from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union
import _pytest._code import _pytest._code
from _pytest.outcomes import fail from _pytest.outcomes import fail
@ -33,12 +32,12 @@ if TYPE_CHECKING:
def _compare_approx( def _compare_approx(
full_object: object, full_object: object,
message_data: Sequence[Tuple[str, str, str]], message_data: Sequence[tuple[str, str, str]],
number_of_elements: int, number_of_elements: int,
different_ids: Sequence[object], different_ids: Sequence[object],
max_abs_diff: float, max_abs_diff: float,
max_rel_diff: float, max_rel_diff: float,
) -> List[str]: ) -> list[str]:
message_list = list(message_data) message_list = list(message_data)
message_list.insert(0, ("Index", "Obtained", "Expected")) message_list.insert(0, ("Index", "Obtained", "Expected"))
max_sizes = [0, 0, 0] max_sizes = [0, 0, 0]
@ -79,7 +78,7 @@ class ApproxBase:
def __repr__(self) -> str: def __repr__(self) -> str:
raise NotImplementedError raise NotImplementedError
def _repr_compare(self, other_side: Any) -> List[str]: def _repr_compare(self, other_side: Any) -> list[str]:
return [ return [
"comparison failed", "comparison failed",
f"Obtained: {other_side}", f"Obtained: {other_side}",
@ -103,7 +102,7 @@ class ApproxBase:
def __ne__(self, actual) -> bool: def __ne__(self, actual) -> bool:
return not (actual == self) return not (actual == self)
def _approx_scalar(self, x) -> "ApproxScalar": def _approx_scalar(self, x) -> ApproxScalar:
if isinstance(x, Decimal): if isinstance(x, Decimal):
return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
@ -142,12 +141,12 @@ class ApproxNumpy(ApproxBase):
) )
return f"approx({list_scalars!r})" return f"approx({list_scalars!r})"
def _repr_compare(self, other_side: Union["ndarray", List[Any]]) -> List[str]: def _repr_compare(self, other_side: ndarray | list[Any]) -> list[str]:
import itertools import itertools
import math import math
def get_value_from_nested_list( def get_value_from_nested_list(
nested_list: List[Any], nd_index: Tuple[Any, ...] nested_list: list[Any], nd_index: tuple[Any, ...]
) -> Any: ) -> Any:
""" """
Helper function to get the value out of a nested list, given an n-dimensional index. Helper function to get the value out of a nested list, given an n-dimensional index.
@ -244,7 +243,7 @@ class ApproxMapping(ApproxBase):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"approx({({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})" return f"approx({({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})"
def _repr_compare(self, other_side: Mapping[object, float]) -> List[str]: def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
import math import math
approx_side_as_map = { approx_side_as_map = {
@ -319,7 +318,7 @@ class ApproxSequenceLike(ApproxBase):
seq_type = list seq_type = list
return f"approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})" return f"approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})"
def _repr_compare(self, other_side: Sequence[float]) -> List[str]: def _repr_compare(self, other_side: Sequence[float]) -> list[str]:
import math import math
if len(self.expected) != len(other_side): if len(self.expected) != len(other_side):
@ -384,8 +383,8 @@ class ApproxScalar(ApproxBase):
# Using Real should be better than this Union, but not possible yet: # Using Real should be better than this Union, but not possible yet:
# https://github.com/python/typeshed/pull/3108 # https://github.com/python/typeshed/pull/3108
DEFAULT_ABSOLUTE_TOLERANCE: Union[float, Decimal] = 1e-12 DEFAULT_ABSOLUTE_TOLERANCE: float | Decimal = 1e-12
DEFAULT_RELATIVE_TOLERANCE: Union[float, Decimal] = 1e-6 DEFAULT_RELATIVE_TOLERANCE: float | Decimal = 1e-6
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return a string communicating both the expected value and the """Return a string communicating both the expected value and the
@ -715,7 +714,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
__tracebackhide__ = True __tracebackhide__ = True
if isinstance(expected, Decimal): if isinstance(expected, Decimal):
cls: Type[ApproxBase] = ApproxDecimal cls: type[ApproxBase] = ApproxDecimal
elif isinstance(expected, Mapping): elif isinstance(expected, Mapping):
cls = ApproxMapping cls = ApproxMapping
elif _is_numpy_array(expected): elif _is_numpy_array(expected):
@ -744,7 +743,7 @@ def _is_numpy_array(obj: object) -> bool:
return _as_numpy_array(obj) is not None return _as_numpy_array(obj) is not None
def _as_numpy_array(obj: object) -> Optional["ndarray"]: def _as_numpy_array(obj: object) -> ndarray | None:
""" """
Return an ndarray if the given object is implicitly convertible to ndarray, Return an ndarray if the given object is implicitly convertible to ndarray,
and numpy is already imported, otherwise None. and numpy is already imported, otherwise None.
@ -770,15 +769,15 @@ E = TypeVar("E", bound=BaseException)
@overload @overload
def raises( def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: type[E] | tuple[type[E], ...],
*, *,
match: Optional[Union[str, Pattern[str]]] = ..., match: str | Pattern[str] | None = ...,
) -> "RaisesContext[E]": ... ) -> RaisesContext[E]: ...
@overload @overload
def raises( def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: type[E] | tuple[type[E], ...],
func: Callable[..., Any], func: Callable[..., Any],
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
@ -786,8 +785,8 @@ def raises(
def raises( def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]], *args: Any, **kwargs: Any expected_exception: type[E] | tuple[type[E], ...], *args: Any, **kwargs: Any
) -> Union["RaisesContext[E]", _pytest._code.ExceptionInfo[E]]: ) -> RaisesContext[E] | _pytest._code.ExceptionInfo[E]:
r"""Assert that a code block/function call raises an exception type, or one of its subclasses. r"""Assert that a code block/function call raises an exception type, or one of its subclasses.
:param expected_exception: :param expected_exception:
@ -935,7 +934,7 @@ def raises(
f"any special code to say 'this should never raise an exception'." f"any special code to say 'this should never raise an exception'."
) )
if isinstance(expected_exception, type): if isinstance(expected_exception, type):
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,) expected_exceptions: tuple[type[E], ...] = (expected_exception,)
else: else:
expected_exceptions = expected_exception expected_exceptions = expected_exception
for exc in expected_exceptions: for exc in expected_exceptions:
@ -947,7 +946,7 @@ def raises(
message = f"DID NOT RAISE {expected_exception}" message = f"DID NOT RAISE {expected_exception}"
if not args: if not args:
match: Optional[Union[str, Pattern[str]]] = kwargs.pop("match", None) match: str | Pattern[str] | None = kwargs.pop("match", None)
if kwargs: if kwargs:
msg = "Unexpected keyword arguments passed to pytest.raises: " msg = "Unexpected keyword arguments passed to pytest.raises: "
msg += ", ".join(sorted(kwargs)) msg += ", ".join(sorted(kwargs))
@ -973,14 +972,14 @@ raises.Exception = fail.Exception # type: ignore
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__( def __init__(
self, self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: type[E] | tuple[type[E], ...],
message: str, message: str,
match_expr: Optional[Union[str, Pattern[str]]] = None, match_expr: str | Pattern[str] | None = None,
) -> None: ) -> None:
self.expected_exception = expected_exception self.expected_exception = expected_exception
self.message = message self.message = message
self.match_expr = match_expr self.match_expr = match_expr
self.excinfo: Optional[_pytest._code.ExceptionInfo[E]] = None self.excinfo: _pytest._code.ExceptionInfo[E] | None = None
def __enter__(self) -> _pytest._code.ExceptionInfo[E]: def __enter__(self) -> _pytest._code.ExceptionInfo[E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later() self.excinfo = _pytest._code.ExceptionInfo.for_later()
@ -988,9 +987,9 @@ class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> bool: ) -> bool:
__tracebackhide__ = True __tracebackhide__ = True
if exc_type is None: if exc_type is None:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import sys import sys
import pytest import pytest

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Record warnings during test function execution.""" """Record warnings during test function execution."""
from __future__ import annotations
from pprint import pformat from pprint import pformat
import re import re
from types import TracebackType from types import TracebackType
@ -9,14 +11,15 @@ from typing import Callable
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import Iterator from typing import Iterator
from typing import List
from typing import Optional
from typing import overload from typing import overload
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import TYPE_CHECKING
from typing import Type
from typing import TypeVar from typing import TypeVar
from typing import Union
if TYPE_CHECKING:
from typing_extensions import Self
import warnings import warnings
from _pytest.deprecated import check_ispytest from _pytest.deprecated import check_ispytest
@ -29,7 +32,7 @@ T = TypeVar("T")
@fixture @fixture
def recwarn() -> Generator["WarningsRecorder", None, None]: def recwarn() -> Generator[WarningsRecorder, None, None]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions. """Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See https://docs.pytest.org/en/latest/how-to/capture-warnings.html for information See https://docs.pytest.org/en/latest/how-to/capture-warnings.html for information
@ -42,9 +45,7 @@ def recwarn() -> Generator["WarningsRecorder", None, None]:
@overload @overload
def deprecated_call( def deprecated_call(*, match: str | Pattern[str] | None = ...) -> WarningsRecorder: ...
*, match: Optional[Union[str, Pattern[str]]] = ...
) -> "WarningsRecorder": ...
@overload @overload
@ -52,8 +53,8 @@ def deprecated_call(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: ...
def deprecated_call( def deprecated_call(
func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any func: Callable[..., Any] | None = None, *args: Any, **kwargs: Any
) -> Union["WarningsRecorder", Any]: ) -> WarningsRecorder | Any:
"""Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning`` or ``FutureWarning``. """Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning`` or ``FutureWarning``.
This function can be used as a context manager:: This function can be used as a context manager::
@ -87,15 +88,15 @@ def deprecated_call(
@overload @overload
def warns( def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = ..., expected_warning: type[Warning] | tuple[type[Warning], ...] = ...,
*, *,
match: Optional[Union[str, Pattern[str]]] = ..., match: str | Pattern[str] | None = ...,
) -> "WarningsChecker": ... ) -> WarningsChecker: ...
@overload @overload
def warns( def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]], expected_warning: type[Warning] | tuple[type[Warning], ...],
func: Callable[..., T], func: Callable[..., T],
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
@ -103,11 +104,11 @@ def warns(
def warns( def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning, expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
*args: Any, *args: Any,
match: Optional[Union[str, Pattern[str]]] = None, match: str | Pattern[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> Union["WarningsChecker", Any]: ) -> WarningsChecker | Any:
r"""Assert that code raises a particular class of warning. r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or tuple Specifically, the parameter ``expected_warning`` can be a warning class or tuple
@ -183,18 +184,18 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
check_ispytest(_ispytest) check_ispytest(_ispytest)
super().__init__(record=True) super().__init__(record=True)
self._entered = False self._entered = False
self._list: List[warnings.WarningMessage] = [] self._list: list[warnings.WarningMessage] = []
@property @property
def list(self) -> List["warnings.WarningMessage"]: def list(self) -> list[warnings.WarningMessage]:
"""The list of recorded warnings.""" """The list of recorded warnings."""
return self._list return self._list
def __getitem__(self, i: int) -> "warnings.WarningMessage": def __getitem__(self, i: int) -> warnings.WarningMessage:
"""Get a recorded warning by index.""" """Get a recorded warning by index."""
return self._list[i] return self._list[i]
def __iter__(self) -> Iterator["warnings.WarningMessage"]: def __iter__(self) -> Iterator[warnings.WarningMessage]:
"""Iterate through the recorded warnings.""" """Iterate through the recorded warnings."""
return iter(self._list) return iter(self._list)
@ -202,12 +203,12 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
"""The number of recorded warnings.""" """The number of recorded warnings."""
return len(self._list) return len(self._list)
def pop(self, cls: Type[Warning] = Warning) -> "warnings.WarningMessage": def pop(self, cls: type[Warning] = Warning) -> warnings.WarningMessage:
"""Pop the first recorded warning which is an instance of ``cls``, """Pop the first recorded warning which is an instance of ``cls``,
but not an instance of a child class of any other match. but not an instance of a child class of any other match.
Raises ``AssertionError`` if there is no match. Raises ``AssertionError`` if there is no match.
""" """
best_idx: Optional[int] = None best_idx: int | None = None
for i, w in enumerate(self._list): for i, w in enumerate(self._list):
if w.category == cls: if w.category == cls:
return self._list.pop(i) # exact match, stop looking return self._list.pop(i) # exact match, stop looking
@ -225,9 +226,7 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
"""Clear the list of recorded warnings.""" """Clear the list of recorded warnings."""
self._list[:] = [] self._list[:] = []
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__ def __enter__(self) -> Self:
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
if self._entered: if self._entered:
__tracebackhide__ = True __tracebackhide__ = True
raise RuntimeError(f"Cannot enter {self!r} twice") raise RuntimeError(f"Cannot enter {self!r} twice")
@ -240,9 +239,9 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
if not self._entered: if not self._entered:
__tracebackhide__ = True __tracebackhide__ = True
@ -259,8 +258,8 @@ class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
class WarningsChecker(WarningsRecorder): class WarningsChecker(WarningsRecorder):
def __init__( def __init__(
self, self,
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning, expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
match_expr: Optional[Union[str, Pattern[str]]] = None, match_expr: str | Pattern[str] | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -291,9 +290,9 @@ class WarningsChecker(WarningsRecorder):
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
super().__exit__(exc_type, exc_val, exc_tb) super().__exit__(exc_type, exc_val, exc_tb)

View File

@ -1,24 +1,19 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses import dataclasses
from io import StringIO from io import StringIO
import os import os
from pprint import pprint from pprint import pprint
from typing import Any from typing import Any
from typing import cast from typing import cast
from typing import Dict
from typing import final from typing import final
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import NoReturn from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo from _pytest._code.code import ExceptionInfo
@ -39,6 +34,8 @@ from _pytest.outcomes import skip
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from _pytest.runner import CallInfo from _pytest.runner import CallInfo
@ -54,16 +51,13 @@ def getworkerinfoline(node):
return s return s
_R = TypeVar("_R", bound="BaseReport")
class BaseReport: class BaseReport:
when: Optional[str] when: str | None
location: Optional[Tuple[str, Optional[int], str]] location: tuple[str, int | None, str] | None
longrepr: Union[ longrepr: (
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
] )
sections: List[Tuple[str, str]] sections: list[tuple[str, str]]
nodeid: str nodeid: str
outcome: Literal["passed", "failed", "skipped"] outcome: Literal["passed", "failed", "skipped"]
@ -94,7 +88,7 @@ class BaseReport:
s = "<unprintable longrepr>" s = "<unprintable longrepr>"
out.line(s) out.line(s)
def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]: def get_sections(self, prefix: str) -> Iterator[tuple[str, str]]:
for name, content in self.sections: for name, content in self.sections:
if name.startswith(prefix): if name.startswith(prefix):
yield prefix, content yield prefix, content
@ -176,7 +170,7 @@ class BaseReport:
return True return True
@property @property
def head_line(self) -> Optional[str]: def head_line(self) -> str | None:
"""**Experimental** The head line shown with longrepr output for this """**Experimental** The head line shown with longrepr output for this
report, more commonly during traceback representation during report, more commonly during traceback representation during
failures:: failures::
@ -202,7 +196,7 @@ class BaseReport:
) )
return verbose return verbose
def _to_json(self) -> Dict[str, Any]: def _to_json(self) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries, """Return the contents of this report as a dict of builtin entries,
suitable for serialization. suitable for serialization.
@ -213,7 +207,7 @@ class BaseReport:
return _report_to_json(self) return _report_to_json(self)
@classmethod @classmethod
def _from_json(cls: Type[_R], reportdict: Dict[str, object]) -> _R: def _from_json(cls, reportdict: dict[str, object]) -> Self:
"""Create either a TestReport or CollectReport, depending on the calling class. """Create either a TestReport or CollectReport, depending on the calling class.
It is the callers responsibility to know which class to pass here. It is the callers responsibility to know which class to pass here.
@ -227,7 +221,7 @@ class BaseReport:
def _report_unserialization_failure( def _report_unserialization_failure(
type_name: str, report_class: Type[BaseReport], reportdict type_name: str, report_class: type[BaseReport], reportdict
) -> NoReturn: ) -> NoReturn:
url = "https://github.com/pytest-dev/pytest/issues" url = "https://github.com/pytest-dev/pytest/issues"
stream = StringIO() stream = StringIO()
@ -256,18 +250,20 @@ class TestReport(BaseReport):
def __init__( def __init__(
self, self,
nodeid: str, nodeid: str,
location: Tuple[str, Optional[int], str], location: tuple[str, int | None, str],
keywords: Mapping[str, Any], keywords: Mapping[str, Any],
outcome: Literal["passed", "failed", "skipped"], outcome: Literal["passed", "failed", "skipped"],
longrepr: Union[ longrepr: None
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr | ExceptionInfo[BaseException]
], | tuple[str, int, str]
| str
| TerminalRepr,
when: Literal["setup", "call", "teardown"], when: Literal["setup", "call", "teardown"],
sections: Iterable[Tuple[str, str]] = (), sections: Iterable[tuple[str, str]] = (),
duration: float = 0, duration: float = 0,
start: float = 0, start: float = 0,
stop: float = 0, stop: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None, user_properties: Iterable[tuple[str, object]] | None = None,
**extra, **extra,
) -> None: ) -> None:
#: Normalized collection nodeid. #: Normalized collection nodeid.
@ -278,7 +274,7 @@ class TestReport(BaseReport):
#: collected one e.g. if a method is inherited from a different module. #: collected one e.g. if a method is inherited from a different module.
#: The filesystempath may be relative to ``config.rootdir``. #: The filesystempath may be relative to ``config.rootdir``.
#: The line number is 0-based. #: The line number is 0-based.
self.location: Tuple[str, Optional[int], str] = location self.location: tuple[str, int | None, str] = location
#: A name -> value dictionary containing all keywords and #: A name -> value dictionary containing all keywords and
#: markers associated with a test invocation. #: markers associated with a test invocation.
@ -317,7 +313,7 @@ class TestReport(BaseReport):
return f"<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>" return f"<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>"
@classmethod @classmethod
def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport": def from_item_and_call(cls, item: Item, call: CallInfo[None]) -> TestReport:
"""Create and fill a TestReport with standard item and call info. """Create and fill a TestReport with standard item and call info.
:param item: The item. :param item: The item.
@ -334,13 +330,13 @@ class TestReport(BaseReport):
sections = [] sections = []
if not call.excinfo: if not call.excinfo:
outcome: Literal["passed", "failed", "skipped"] = "passed" outcome: Literal["passed", "failed", "skipped"] = "passed"
longrepr: Union[ longrepr: (
None, None
ExceptionInfo[BaseException], | ExceptionInfo[BaseException]
Tuple[str, int, str], | tuple[str, int, str]
str, | str
TerminalRepr, | TerminalRepr
] = None ) = None
else: else:
if not isinstance(excinfo, ExceptionInfo): if not isinstance(excinfo, ExceptionInfo):
outcome = "failed" outcome = "failed"
@ -394,12 +390,14 @@ class CollectReport(BaseReport):
def __init__( def __init__(
self, self,
nodeid: str, nodeid: str,
outcome: "Literal['passed', 'failed', 'skipped']", outcome: Literal["passed", "failed", "skipped"],
longrepr: Union[ longrepr: None
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr | ExceptionInfo[BaseException]
], | tuple[str, int, str]
result: Optional[List[Union[Item, Collector]]], | str
sections: Iterable[Tuple[str, str]] = (), | TerminalRepr,
result: list[Item | Collector] | None,
sections: Iterable[tuple[str, str]] = (),
**extra, **extra,
) -> None: ) -> None:
#: Normalized collection nodeid. #: Normalized collection nodeid.
@ -425,7 +423,7 @@ class CollectReport(BaseReport):
@property @property
def location( # type:ignore[override] def location( # type:ignore[override]
self, self,
) -> Optional[Tuple[str, Optional[int], str]]: ) -> tuple[str, int | None, str] | None:
return (self.fspath, None, self.fspath) return (self.fspath, None, self.fspath)
def __repr__(self) -> str: def __repr__(self) -> str:
@ -441,8 +439,8 @@ class CollectErrorRepr(TerminalRepr):
def pytest_report_to_serializable( def pytest_report_to_serializable(
report: Union[CollectReport, TestReport], report: CollectReport | TestReport,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
if isinstance(report, (TestReport, CollectReport)): if isinstance(report, (TestReport, CollectReport)):
data = report._to_json() data = report._to_json()
data["$report_type"] = report.__class__.__name__ data["$report_type"] = report.__class__.__name__
@ -452,8 +450,8 @@ def pytest_report_to_serializable(
def pytest_report_from_serializable( def pytest_report_from_serializable(
data: Dict[str, Any], data: dict[str, Any],
) -> Optional[Union[CollectReport, TestReport]]: ) -> CollectReport | TestReport | None:
if "$report_type" in data: if "$report_type" in data:
if data["$report_type"] == "TestReport": if data["$report_type"] == "TestReport":
return TestReport._from_json(data) return TestReport._from_json(data)
@ -465,7 +463,7 @@ def pytest_report_from_serializable(
return None return None
def _report_to_json(report: BaseReport) -> Dict[str, Any]: def _report_to_json(report: BaseReport) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries, """Return the contents of this report as a dict of builtin entries,
suitable for serialization. suitable for serialization.
@ -473,8 +471,8 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
""" """
def serialize_repr_entry( def serialize_repr_entry(
entry: Union[ReprEntry, ReprEntryNative], entry: ReprEntry | ReprEntryNative,
) -> Dict[str, Any]: ) -> dict[str, Any]:
data = dataclasses.asdict(entry) data = dataclasses.asdict(entry)
for key, value in data.items(): for key, value in data.items():
if hasattr(value, "__dict__"): if hasattr(value, "__dict__"):
@ -482,7 +480,7 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
entry_data = {"type": type(entry).__name__, "data": data} entry_data = {"type": type(entry).__name__, "data": data}
return entry_data return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]: def serialize_repr_traceback(reprtraceback: ReprTraceback) -> dict[str, Any]:
result = dataclasses.asdict(reprtraceback) result = dataclasses.asdict(reprtraceback)
result["reprentries"] = [ result["reprentries"] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries serialize_repr_entry(x) for x in reprtraceback.reprentries
@ -490,18 +488,18 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
return result return result
def serialize_repr_crash( def serialize_repr_crash(
reprcrash: Optional[ReprFileLocation], reprcrash: ReprFileLocation | None,
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
if reprcrash is not None: if reprcrash is not None:
return dataclasses.asdict(reprcrash) return dataclasses.asdict(reprcrash)
else: else:
return None return None
def serialize_exception_longrepr(rep: BaseReport) -> Dict[str, Any]: def serialize_exception_longrepr(rep: BaseReport) -> dict[str, Any]:
assert rep.longrepr is not None assert rep.longrepr is not None
# TODO: Investigate whether the duck typing is really necessary here. # TODO: Investigate whether the duck typing is really necessary here.
longrepr = cast(ExceptionRepr, rep.longrepr) longrepr = cast(ExceptionRepr, rep.longrepr)
result: Dict[str, Any] = { result: dict[str, Any] = {
"reprcrash": serialize_repr_crash(longrepr.reprcrash), "reprcrash": serialize_repr_crash(longrepr.reprcrash),
"reprtraceback": serialize_repr_traceback(longrepr.reprtraceback), "reprtraceback": serialize_repr_traceback(longrepr.reprtraceback),
"sections": longrepr.sections, "sections": longrepr.sections,
@ -538,7 +536,7 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
return d return d
def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]: def _report_kwargs_from_json(reportdict: dict[str, Any]) -> dict[str, Any]:
"""Return **kwargs that can be used to construct a TestReport or """Return **kwargs that can be used to construct a TestReport or
CollectReport instance. CollectReport instance.
@ -559,7 +557,7 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
if data["reprlocals"]: if data["reprlocals"]:
reprlocals = ReprLocals(data["reprlocals"]["lines"]) reprlocals = ReprLocals(data["reprlocals"]["lines"])
reprentry: Union[ReprEntry, ReprEntryNative] = ReprEntry( reprentry: ReprEntry | ReprEntryNative = ReprEntry(
lines=data["lines"], lines=data["lines"],
reprfuncargs=reprfuncargs, reprfuncargs=reprfuncargs,
reprlocals=reprlocals, reprlocals=reprlocals,
@ -578,7 +576,7 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
] ]
return ReprTraceback(**repr_traceback_dict) return ReprTraceback(**repr_traceback_dict)
def deserialize_repr_crash(repr_crash_dict: Optional[Dict[str, Any]]): def deserialize_repr_crash(repr_crash_dict: dict[str, Any] | None):
if repr_crash_dict is not None: if repr_crash_dict is not None:
return ReprFileLocation(**repr_crash_dict) return ReprFileLocation(**repr_crash_dict)
else: else:
@ -605,8 +603,8 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
description, description,
) )
) )
exception_info: Union[ExceptionChainRepr, ReprExceptionInfo] = ( exception_info: ExceptionChainRepr | ReprExceptionInfo = ExceptionChainRepr(
ExceptionChainRepr(chain) chain
) )
else: else:
exception_info = ReprExceptionInfo( exception_info = ReprExceptionInfo(

View File

@ -1,23 +1,19 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Basic collect and runtest protocol implementations.""" """Basic collect and runtest protocol implementations."""
from __future__ import annotations
import bdb import bdb
import dataclasses import dataclasses
import os import os
import sys import sys
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import Dict
from typing import final from typing import final
from typing import Generic from typing import Generic
from typing import List
from typing import Literal from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union
from .reports import BaseReport from .reports import BaseReport
from .reports import CollectErrorRepr from .reports import CollectErrorRepr
@ -71,7 +67,7 @@ def pytest_addoption(parser: Parser) -> None:
) )
def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None: def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
durations = terminalreporter.config.option.durations durations = terminalreporter.config.option.durations
durations_min = terminalreporter.config.option.durations_min durations_min = terminalreporter.config.option.durations_min
verbose = terminalreporter.config.getvalue("verbose") verbose = terminalreporter.config.getvalue("verbose")
@ -102,15 +98,15 @@ def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None:
tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}") tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}")
def pytest_sessionstart(session: "Session") -> None: def pytest_sessionstart(session: Session) -> None:
session._setupstate = SetupState() session._setupstate = SetupState()
def pytest_sessionfinish(session: "Session") -> None: def pytest_sessionfinish(session: Session) -> None:
session._setupstate.teardown_exact(None) session._setupstate.teardown_exact(None)
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool: def pytest_runtest_protocol(item: Item, nextitem: Item | None) -> bool:
ihook = item.ihook ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location) ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem) runtestprotocol(item, nextitem=nextitem)
@ -119,8 +115,8 @@ def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
def runtestprotocol( def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None item: Item, log: bool = True, nextitem: Item | None = None
) -> List[TestReport]: ) -> list[TestReport]:
hasrequest = hasattr(item, "_request") hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined] if hasrequest and not item._request: # type: ignore[attr-defined]
# This only happens if the item is re-run, as is done by # This only happens if the item is re-run, as is done by
@ -183,14 +179,14 @@ def pytest_runtest_call(item: Item) -> None:
raise e raise e
def pytest_runtest_teardown(item: Item, nextitem: Optional[Item]) -> None: def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
_update_current_test_var(item, "teardown") _update_current_test_var(item, "teardown")
item.session._setupstate.teardown_exact(nextitem) item.session._setupstate.teardown_exact(nextitem)
_update_current_test_var(item, None) _update_current_test_var(item, None)
def _update_current_test_var( def _update_current_test_var(
item: Item, when: Optional[Literal["setup", "call", "teardown"]] item: Item, when: Literal["setup", "call", "teardown"] | None
) -> None: ) -> None:
"""Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage. """Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
@ -206,7 +202,7 @@ def _update_current_test_var(
os.environ.pop(var_name) os.environ.pop(var_name)
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]: def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if report.when in ("setup", "teardown"): if report.when in ("setup", "teardown"):
if report.failed: if report.failed:
# category, shortletter, verbose-word # category, shortletter, verbose-word
@ -234,7 +230,7 @@ def call_and_report(
runtest_hook = ihook.pytest_runtest_teardown runtest_hook = ihook.pytest_runtest_teardown
else: else:
assert False, f"Unhandled runtest hook case: {when}" assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,) reraise: tuple[type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False): if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,) reraise += (KeyboardInterrupt,)
call = CallInfo.from_call( call = CallInfo.from_call(
@ -248,7 +244,7 @@ def call_and_report(
return report return report
def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) -> bool: def check_interactive_exception(call: CallInfo[object], report: BaseReport) -> bool:
"""Check whether the call raised an exception that should be reported as """Check whether the call raised an exception that should be reported as
interactive.""" interactive."""
if call.excinfo is None: if call.excinfo is None:
@ -271,9 +267,9 @@ TResult = TypeVar("TResult", covariant=True)
class CallInfo(Generic[TResult]): class CallInfo(Generic[TResult]):
"""Result/Exception info of a function invocation.""" """Result/Exception info of a function invocation."""
_result: Optional[TResult] _result: TResult | None
#: The captured exception of the call, if it raised. #: The captured exception of the call, if it raised.
excinfo: Optional[ExceptionInfo[BaseException]] excinfo: ExceptionInfo[BaseException] | None
#: The system time when the call started, in seconds since the epoch. #: The system time when the call started, in seconds since the epoch.
start: float start: float
#: The system time when the call ended, in seconds since the epoch. #: The system time when the call ended, in seconds since the epoch.
@ -285,8 +281,8 @@ class CallInfo(Generic[TResult]):
def __init__( def __init__(
self, self,
result: Optional[TResult], result: TResult | None,
excinfo: Optional[ExceptionInfo[BaseException]], excinfo: ExceptionInfo[BaseException] | None,
start: float, start: float,
stop: float, stop: float,
duration: float, duration: float,
@ -320,14 +316,13 @@ class CallInfo(Generic[TResult]):
cls, cls,
func: Callable[[], TResult], func: Callable[[], TResult],
when: Literal["collect", "setup", "call", "teardown"], when: Literal["collect", "setup", "call", "teardown"],
reraise: Optional[ reraise: type[BaseException] | tuple[type[BaseException], ...] | None = None,
Union[Type[BaseException], Tuple[Type[BaseException], ...]] ) -> CallInfo[TResult]:
] = None,
) -> "CallInfo[TResult]":
"""Call func, wrapping the result in a CallInfo. """Call func, wrapping the result in a CallInfo.
:param func: :param func:
The function to call. Called without arguments. The function to call. Called without arguments.
:type func: Callable[[], _pytest.runner.TResult]
:param when: :param when:
The phase in which the function is called. The phase in which the function is called.
:param reraise: :param reraise:
@ -338,7 +333,7 @@ class CallInfo(Generic[TResult]):
start = timing.time() start = timing.time()
precise_start = timing.perf_counter() precise_start = timing.perf_counter()
try: try:
result: Optional[TResult] = func() result: TResult | None = func()
except BaseException: except BaseException:
excinfo = ExceptionInfo.from_current() excinfo = ExceptionInfo.from_current()
if reraise is not None and isinstance(excinfo.value, reraise): if reraise is not None and isinstance(excinfo.value, reraise):
@ -369,7 +364,7 @@ def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
def pytest_make_collect_report(collector: Collector) -> CollectReport: def pytest_make_collect_report(collector: Collector) -> CollectReport:
def collect() -> List[Union[Item, Collector]]: def collect() -> list[Item | Collector]:
# Before collecting, if this is a Directory, load the conftests. # Before collecting, if this is a Directory, load the conftests.
# If a conftest import fails to load, it is considered a collection # If a conftest import fails to load, it is considered a collection
# error of the Directory collector. This is why it's done inside of the # error of the Directory collector. This is why it's done inside of the
@ -391,7 +386,7 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
call = CallInfo.from_call( call = CallInfo.from_call(
collect, "collect", reraise=(KeyboardInterrupt, SystemExit) collect, "collect", reraise=(KeyboardInterrupt, SystemExit)
) )
longrepr: Union[None, Tuple[str, int, str], str, TerminalRepr] = None longrepr: None | tuple[str, int, str] | str | TerminalRepr = None
if not call.excinfo: if not call.excinfo:
outcome: Literal["passed", "skipped", "failed"] = "passed" outcome: Literal["passed", "skipped", "failed"] = "passed"
else: else:
@ -485,13 +480,13 @@ class SetupState:
def __init__(self) -> None: def __init__(self) -> None:
# The stack is in the dict insertion order. # The stack is in the dict insertion order.
self.stack: Dict[ self.stack: dict[
Node, Node,
Tuple[ tuple[
# Node's finalizers. # Node's finalizers.
List[Callable[[], object]], list[Callable[[], object]],
# Node's exception, if its setup raised. # Node's exception and original traceback, if its setup raised.
Optional[Union[OutcomeException, Exception]], OutcomeException | Exception | None,
], ],
] = {} ] = {}
@ -526,7 +521,7 @@ class SetupState:
assert node in self.stack, (node, self.stack) assert node in self.stack, (node, self.stack)
self.stack[node][0].append(finalizer) self.stack[node][0].append(finalizer)
def teardown_exact(self, nextitem: Optional[Item]) -> None: def teardown_exact(self, nextitem: Item | None) -> None:
"""Teardown the current stack up until reaching nodes that nextitem """Teardown the current stack up until reaching nodes that nextitem
also descends from. also descends from.
@ -534,7 +529,7 @@ class SetupState:
stack is torn down. stack is torn down.
""" """
needed_collectors = nextitem and nextitem.listchain() or [] needed_collectors = nextitem and nextitem.listchain() or []
exceptions: List[BaseException] = [] exceptions: list[BaseException] = []
while self.stack: while self.stack:
if list(self.stack.keys()) == needed_collectors[: len(self.stack)]: if list(self.stack.keys()) == needed_collectors[: len(self.stack)]:
break break

View File

@ -8,10 +8,11 @@ would cause circular references.
Also this makes the module light to import, as it should. Also this makes the module light to import, as it should.
""" """
from __future__ import annotations
from enum import Enum from enum import Enum
from functools import total_ordering from functools import total_ordering
from typing import Literal from typing import Literal
from typing import Optional
_ScopeName = Literal["session", "package", "module", "class", "function"] _ScopeName = Literal["session", "package", "module", "class", "function"]
@ -38,29 +39,29 @@ class Scope(Enum):
Package: _ScopeName = "package" Package: _ScopeName = "package"
Session: _ScopeName = "session" Session: _ScopeName = "session"
def next_lower(self) -> "Scope": def next_lower(self) -> Scope:
"""Return the next lower scope.""" """Return the next lower scope."""
index = _SCOPE_INDICES[self] index = _SCOPE_INDICES[self]
if index == 0: if index == 0:
raise ValueError(f"{self} is the lower-most scope") raise ValueError(f"{self} is the lower-most scope")
return _ALL_SCOPES[index - 1] return _ALL_SCOPES[index - 1]
def next_higher(self) -> "Scope": def next_higher(self) -> Scope:
"""Return the next higher scope.""" """Return the next higher scope."""
index = _SCOPE_INDICES[self] index = _SCOPE_INDICES[self]
if index == len(_SCOPE_INDICES) - 1: if index == len(_SCOPE_INDICES) - 1:
raise ValueError(f"{self} is the upper-most scope") raise ValueError(f"{self} is the upper-most scope")
return _ALL_SCOPES[index + 1] return _ALL_SCOPES[index + 1]
def __lt__(self, other: "Scope") -> bool: def __lt__(self, other: Scope) -> bool:
self_index = _SCOPE_INDICES[self] self_index = _SCOPE_INDICES[self]
other_index = _SCOPE_INDICES[other] other_index = _SCOPE_INDICES[other]
return self_index < other_index return self_index < other_index
@classmethod @classmethod
def from_user( def from_user(
cls, scope_name: _ScopeName, descr: str, where: Optional[str] = None cls, scope_name: _ScopeName, descr: str, where: str | None = None
) -> "Scope": ) -> Scope:
""" """
Given a scope name from the user, return the equivalent Scope enum. Should be used Given a scope name from the user, return the equivalent Scope enum. Should be used
whenever we want to convert a user provided scope name to its enum object. whenever we want to convert a user provided scope name to its enum object.

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Generator from typing import Generator
from typing import Optional
from typing import Union
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.config import Config from _pytest.config import Config
@ -96,7 +96,7 @@ def _show_fixture_action(
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setuponly: if config.option.setuponly:
config.option.setupshow = True config.option.setupshow = True
return None return None

View File

@ -1,5 +1,4 @@
from typing import Optional from __future__ import annotations
from typing import Union
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
@ -23,7 +22,7 @@ def pytest_addoption(parser: Parser) -> None:
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup( def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest fixturedef: FixtureDef[object], request: SubRequest
) -> Optional[object]: ) -> object | None:
# Will return a dummy fixture if the setuponly option is provided. # Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan: if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request) my_cache_key = fixturedef.cache_key(request)
@ -33,7 +32,7 @@ def pytest_fixture_setup(
@pytest.hookimpl(tryfirst=True) @pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]: def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setupplan: if config.option.setupplan:
config.option.setuponly = True config.option.setuponly = True
config.option.setupshow = True config.option.setupshow = True

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Support for skip/xfail functions and markers.""" """Support for skip/xfail functions and markers."""
from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import dataclasses import dataclasses
import os import os
@ -9,8 +11,6 @@ import sys
import traceback import traceback
from typing import Generator from typing import Generator
from typing import Optional from typing import Optional
from typing import Tuple
from typing import Type
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import hookimpl from _pytest.config import hookimpl
@ -84,7 +84,7 @@ def pytest_configure(config: Config) -> None:
) )
def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool, str]: def evaluate_condition(item: Item, mark: Mark, condition: object) -> tuple[bool, str]:
"""Evaluate a single skipif/xfail condition. """Evaluate a single skipif/xfail condition.
If an old-style string condition is given, it is eval()'d, otherwise the If an old-style string condition is given, it is eval()'d, otherwise the
@ -164,7 +164,7 @@ class Skip:
reason: str = "unconditional skip" reason: str = "unconditional skip"
def evaluate_skip_marks(item: Item) -> Optional[Skip]: def evaluate_skip_marks(item: Item) -> Skip | None:
"""Evaluate skip and skipif marks on item, returning Skip if triggered.""" """Evaluate skip and skipif marks on item, returning Skip if triggered."""
for mark in item.iter_markers(name="skipif"): for mark in item.iter_markers(name="skipif"):
if "condition" not in mark.kwargs: if "condition" not in mark.kwargs:
@ -201,10 +201,10 @@ class Xfail:
reason: str reason: str
run: bool run: bool
strict: bool strict: bool
raises: Optional[Tuple[Type[BaseException], ...]] raises: tuple[type[BaseException], ...] | None
def evaluate_xfail_marks(item: Item) -> Optional[Xfail]: def evaluate_xfail_marks(item: Item) -> Xfail | None:
"""Evaluate xfail marks on item, returning Xfail if triggered.""" """Evaluate xfail marks on item, returning Xfail if triggered."""
for mark in item.iter_markers(name="xfail"): for mark in item.iter_markers(name="xfail"):
run = mark.kwargs.get("run", True) run = mark.kwargs.get("run", True)
@ -292,7 +292,7 @@ def pytest_runtest_makereport(
return rep return rep
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]: def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if hasattr(report, "wasxfail"): if hasattr(report, "wasxfail"):
if report.skipped: if report.skipped:
return "xfailed", "x", "XFAIL" return "xfailed", "x", "XFAIL"

View File

@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Any from typing import Any
from typing import cast from typing import cast
from typing import Dict
from typing import Generic from typing import Generic
from typing import TypeVar from typing import TypeVar
from typing import Union
__all__ = ["Stash", "StashKey"] __all__ = ["Stash", "StashKey"]
@ -70,7 +70,7 @@ class Stash:
__slots__ = ("_storage",) __slots__ = ("_storage",)
def __init__(self) -> None: def __init__(self) -> None:
self._storage: Dict[StashKey[Any], object] = {} self._storage: dict[StashKey[Any], object] = {}
def __setitem__(self, key: StashKey[T], value: T) -> None: def __setitem__(self, key: StashKey[T], value: T) -> None:
"""Set a value for key.""" """Set a value for key."""
@ -83,7 +83,7 @@ class Stash:
""" """
return cast(T, self._storage[key]) return cast(T, self._storage[key])
def get(self, key: StashKey[T], default: D) -> Union[T, D]: def get(self, key: StashKey[T], default: D) -> T | D:
"""Get the value for key, or return default if the key wasn't set """Get the value for key, or return default if the key wasn't set
before.""" before."""
try: try:

View File

@ -1,5 +1,5 @@
from typing import List from __future__ import annotations
from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from _pytest import nodes from _pytest import nodes
@ -13,6 +13,7 @@ import pytest
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest.cacheprovider import Cache from _pytest.cacheprovider import Cache
STEPWISE_CACHE_DIR = "cache/stepwise" STEPWISE_CACHE_DIR = "cache/stepwise"
@ -60,18 +61,18 @@ def pytest_sessionfinish(session: Session) -> None:
class StepwisePlugin: class StepwisePlugin:
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
self.config = config self.config = config
self.session: Optional[Session] = None self.session: Session | None = None
self.report_status = "" self.report_status = ""
assert config.cache is not None assert config.cache is not None
self.cache: Cache = config.cache self.cache: Cache = config.cache
self.lastfailed: Optional[str] = self.cache.get(STEPWISE_CACHE_DIR, None) self.lastfailed: str | None = self.cache.get(STEPWISE_CACHE_DIR, None)
self.skip: bool = config.getoption("stepwise_skip") self.skip: bool = config.getoption("stepwise_skip")
def pytest_sessionstart(self, session: Session) -> None: def pytest_sessionstart(self, session: Session) -> None:
self.session = session self.session = session
def pytest_collection_modifyitems( def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item] self, config: Config, items: list[nodes.Item]
) -> None: ) -> None:
if not self.lastfailed: if not self.lastfailed:
self.report_status = "no previously failed tests, not skipping." self.report_status = "no previously failed tests, not skipping."
@ -118,7 +119,7 @@ class StepwisePlugin:
if report.nodeid == self.lastfailed: if report.nodeid == self.lastfailed:
self.lastfailed = None self.lastfailed = None
def pytest_report_collectionfinish(self) -> Optional[str]: def pytest_report_collectionfinish(self) -> str | None:
if self.config.getoption("verbose") >= 0 and self.report_status: if self.config.getoption("verbose") >= 0 and self.report_status:
return f"stepwise: {self.report_status}" return f"stepwise: {self.report_status}"
return None return None

View File

@ -4,6 +4,8 @@
This is a good source for looking at the various reporting hooks. This is a good source for looking at the various reporting hooks.
""" """
from __future__ import annotations
import argparse import argparse
from collections import Counter from collections import Counter
import dataclasses import dataclasses
@ -17,20 +19,14 @@ import textwrap
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import ClassVar from typing import ClassVar
from typing import Dict
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import NamedTuple from typing import NamedTuple
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set
from typing import TextIO from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import warnings import warnings
import pluggy import pluggy
@ -90,7 +86,7 @@ class MoreQuietAction(argparse.Action):
dest: str, dest: str,
default: object = None, default: object = None,
required: bool = False, required: bool = False,
help: Optional[str] = None, help: str | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
option_strings=option_strings, option_strings=option_strings,
@ -105,8 +101,8 @@ class MoreQuietAction(argparse.Action):
self, self,
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser,
namespace: argparse.Namespace, namespace: argparse.Namespace,
values: Union[str, Sequence[object], None], values: str | Sequence[object] | None,
option_string: Optional[str] = None, option_string: str | None = None,
) -> None: ) -> None:
new_count = getattr(namespace, self.dest, 0) - 1 new_count = getattr(namespace, self.dest, 0) - 1
setattr(namespace, self.dest, new_count) setattr(namespace, self.dest, new_count)
@ -131,7 +127,7 @@ class TestShortLogReport(NamedTuple):
category: str category: str
letter: str letter: str
word: Union[str, Tuple[str, Mapping[str, bool]]] word: str | tuple[str, Mapping[str, bool]]
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
@ -304,7 +300,7 @@ def getreportopt(config: Config) -> str:
@hookimpl(trylast=True) # after _pytest.runner @hookimpl(trylast=True) # after _pytest.runner
def pytest_report_teststatus(report: BaseReport) -> Tuple[str, str, str]: def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str]:
letter = "F" letter = "F"
if report.passed: if report.passed:
letter = "." letter = "."
@ -332,12 +328,12 @@ class WarningReport:
""" """
message: str message: str
nodeid: Optional[str] = None nodeid: str | None = None
fslocation: Optional[Tuple[str, int]] = None fslocation: tuple[str, int] | None = None
count_towards_summary: ClassVar = True count_towards_summary: ClassVar = True
def get_location(self, config: Config) -> Optional[str]: def get_location(self, config: Config) -> str | None:
"""Return the more user-friendly information about the location of a warning, or None.""" """Return the more user-friendly information about the location of a warning, or None."""
if self.nodeid: if self.nodeid:
return self.nodeid return self.nodeid
@ -350,31 +346,31 @@ class WarningReport:
@final @final
class TerminalReporter: class TerminalReporter:
def __init__(self, config: Config, file: Optional[TextIO] = None) -> None: def __init__(self, config: Config, file: TextIO | None = None) -> None:
import _pytest.config import _pytest.config
self.config = config self.config = config
self._numcollected = 0 self._numcollected = 0
self._session: Optional[Session] = None self._session: Session | None = None
self._showfspath: Optional[bool] = None self._showfspath: bool | None = None
self.stats: Dict[str, List[Any]] = {} self.stats: dict[str, list[Any]] = {}
self._main_color: Optional[str] = None self._main_color: str | None = None
self._known_types: Optional[List[str]] = None self._known_types: list[str] | None = None
self.startpath = config.invocation_params.dir self.startpath = config.invocation_params.dir
if file is None: if file is None:
file = sys.stdout file = sys.stdout
self._tw = _pytest.config.create_terminal_writer(config, file) self._tw = _pytest.config.create_terminal_writer(config, file)
self._screen_width = self._tw.fullwidth self._screen_width = self._tw.fullwidth
self.currentfspath: Union[None, Path, str, int] = None self.currentfspath: None | Path | str | int = None
self.reportchars = getreportopt(config) self.reportchars = getreportopt(config)
self.hasmarkup = self._tw.hasmarkup self.hasmarkup = self._tw.hasmarkup
self.isatty = file.isatty() self.isatty = file.isatty()
self._progress_nodeids_reported: Set[str] = set() self._progress_nodeids_reported: set[str] = set()
self._show_progress_info = self._determine_show_progress_info() self._show_progress_info = self._determine_show_progress_info()
self._collect_report_last_write: Optional[float] = None self._collect_report_last_write: float | None = None
self._already_displayed_warnings: Optional[int] = None self._already_displayed_warnings: int | None = None
self._keyboardinterrupt_memo: Optional[ExceptionRepr] = None self._keyboardinterrupt_memo: ExceptionRepr | None = None
def _determine_show_progress_info(self) -> Literal["progress", "count", False]: def _determine_show_progress_info(self) -> Literal["progress", "count", False]:
"""Return whether we should display progress information based on the current config.""" """Return whether we should display progress information based on the current config."""
@ -421,7 +417,7 @@ class TerminalReporter:
return self._showfspath return self._showfspath
@showfspath.setter @showfspath.setter
def showfspath(self, value: Optional[bool]) -> None: def showfspath(self, value: bool | None) -> None:
self._showfspath = value self._showfspath = value
@property @property
@ -485,7 +481,7 @@ class TerminalReporter:
def flush(self) -> None: def flush(self) -> None:
self._tw.flush() self._tw.flush()
def write_line(self, line: Union[str, bytes], **markup: bool) -> None: def write_line(self, line: str | bytes, **markup: bool) -> None:
if not isinstance(line, str): if not isinstance(line, str):
line = str(line, errors="replace") line = str(line, errors="replace")
self.ensure_newline() self.ensure_newline()
@ -512,8 +508,8 @@ class TerminalReporter:
def write_sep( def write_sep(
self, self,
sep: str, sep: str,
title: Optional[str] = None, title: str | None = None,
fullwidth: Optional[int] = None, fullwidth: int | None = None,
**markup: bool, **markup: bool,
) -> None: ) -> None:
self.ensure_newline() self.ensure_newline()
@ -563,7 +559,7 @@ class TerminalReporter:
self._add_stats("deselected", items) self._add_stats("deselected", items)
def pytest_runtest_logstart( def pytest_runtest_logstart(
self, nodeid: str, location: Tuple[str, Optional[int], str] self, nodeid: str, location: tuple[str, int | None, str]
) -> None: ) -> None:
# Ensure that the path is printed before the # Ensure that the path is printed before the
# 1st test of a module starts running. # 1st test of a module starts running.
@ -757,7 +753,7 @@ class TerminalReporter:
self.write_line(line) self.write_line(line)
@hookimpl(trylast=True) @hookimpl(trylast=True)
def pytest_sessionstart(self, session: "Session") -> None: def pytest_sessionstart(self, session: Session) -> None:
self._session = session self._session = session
self._sessionstarttime = timing.time() self._sessionstarttime = timing.time()
if not self.showheader: if not self.showheader:
@ -784,7 +780,7 @@ class TerminalReporter:
self._write_report_lines_from_hooks(lines) self._write_report_lines_from_hooks(lines)
def _write_report_lines_from_hooks( def _write_report_lines_from_hooks(
self, lines: Sequence[Union[str, Sequence[str]]] self, lines: Sequence[str | Sequence[str]]
) -> None: ) -> None:
for line_or_lines in reversed(lines): for line_or_lines in reversed(lines):
if isinstance(line_or_lines, str): if isinstance(line_or_lines, str):
@ -793,14 +789,14 @@ class TerminalReporter:
for line in line_or_lines: for line in line_or_lines:
self.write_line(line) self.write_line(line)
def pytest_report_header(self, config: Config) -> List[str]: def pytest_report_header(self, config: Config) -> list[str]:
result = [f"rootdir: {config.rootpath}"] result = [f"rootdir: {config.rootpath}"]
if config.inipath: if config.inipath:
result.append("configfile: " + bestrelpath(config.rootpath, config.inipath)) result.append("configfile: " + bestrelpath(config.rootpath, config.inipath))
if config.args_source == Config.ArgsSource.TESTPATHS: if config.args_source == Config.ArgsSource.TESTPATHS:
testpaths: List[str] = config.getini("testpaths") testpaths: list[str] = config.getini("testpaths")
result.append("testpaths: {}".format(", ".join(testpaths))) result.append("testpaths: {}".format(", ".join(testpaths)))
plugininfo = config.pluginmanager.list_plugin_distinfo() plugininfo = config.pluginmanager.list_plugin_distinfo()
@ -808,7 +804,7 @@ class TerminalReporter:
result.append("plugins: %s" % ", ".join(_plugin_nameversions(plugininfo))) result.append("plugins: %s" % ", ".join(_plugin_nameversions(plugininfo)))
return result return result
def pytest_collection_finish(self, session: "Session") -> None: def pytest_collection_finish(self, session: Session) -> None:
self.report_collect(True) self.report_collect(True)
lines = self.config.hook.pytest_report_collectionfinish( lines = self.config.hook.pytest_report_collectionfinish(
@ -841,7 +837,7 @@ class TerminalReporter:
for item in items: for item in items:
self._tw.line(item.nodeid) self._tw.line(item.nodeid)
return return
stack: List[Node] = [] stack: list[Node] = []
indent = "" indent = ""
for item in items: for item in items:
needed_collectors = item.listchain()[1:] # strip root node needed_collectors = item.listchain()[1:] # strip root node
@ -862,7 +858,7 @@ class TerminalReporter:
@hookimpl(wrapper=True) @hookimpl(wrapper=True)
def pytest_sessionfinish( def pytest_sessionfinish(
self, session: "Session", exitstatus: Union[int, ExitCode] self, session: Session, exitstatus: int | ExitCode
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
result = yield result = yield
self._tw.line("") self._tw.line("")
@ -926,7 +922,7 @@ class TerminalReporter:
) )
def _locationline( def _locationline(
self, nodeid: str, fspath: str, lineno: Optional[int], domain: str self, nodeid: str, fspath: str, lineno: int | None, domain: str
) -> str: ) -> str:
def mkrel(nodeid: str) -> str: def mkrel(nodeid: str) -> str:
line = self.config.cwd_relative_nodeid(nodeid) line = self.config.cwd_relative_nodeid(nodeid)
@ -971,7 +967,7 @@ class TerminalReporter:
def summary_warnings(self) -> None: def summary_warnings(self) -> None:
if self.hasopt("w"): if self.hasopt("w"):
all_warnings: Optional[List[WarningReport]] = self.stats.get("warnings") all_warnings: list[WarningReport] | None = self.stats.get("warnings")
if not all_warnings: if not all_warnings:
return return
@ -984,11 +980,11 @@ class TerminalReporter:
if not warning_reports: if not warning_reports:
return return
reports_grouped_by_message: Dict[str, List[WarningReport]] = {} reports_grouped_by_message: dict[str, list[WarningReport]] = {}
for wr in warning_reports: for wr in warning_reports:
reports_grouped_by_message.setdefault(wr.message, []).append(wr) reports_grouped_by_message.setdefault(wr.message, []).append(wr)
def collapsed_location_report(reports: List[WarningReport]) -> str: def collapsed_location_report(reports: list[WarningReport]) -> str:
locations = [] locations = []
for w in reports: for w in reports:
location = w.get_location(self.config) location = w.get_location(self.config)
@ -1034,7 +1030,7 @@ class TerminalReporter:
) -> None: ) -> None:
if self.config.option.tbstyle != "no": if self.config.option.tbstyle != "no":
if self.hasopt(needed_opt): if self.hasopt(needed_opt):
reports: List[TestReport] = self.getreports(which_reports) reports: list[TestReport] = self.getreports(which_reports)
if not reports: if not reports:
return return
self.write_sep("=", sep_title) self.write_sep("=", sep_title)
@ -1045,7 +1041,7 @@ class TerminalReporter:
self._outrep_summary(rep) self._outrep_summary(rep)
self._handle_teardown_sections(rep.nodeid) self._handle_teardown_sections(rep.nodeid)
def _get_teardown_reports(self, nodeid: str) -> List[TestReport]: def _get_teardown_reports(self, nodeid: str) -> list[TestReport]:
reports = self.getreports("") reports = self.getreports("")
return [ return [
report report
@ -1074,14 +1070,17 @@ class TerminalReporter:
self.summary_failures_combined("failed", "FAILURES") self.summary_failures_combined("failed", "FAILURES")
def summary_xfailures(self) -> None: def summary_xfailures(self) -> None:
self.summary_failures_combined("xfailed", "XFAILURES", "x") self.summary_failures_combined("xfailed", "XFAILURES", needed_opt="x")
def summary_failures_combined( def summary_failures_combined(
self, which_reports: str, sep_title: str, needed_opt: Optional[str] = None self,
which_reports: str,
sep_title: str,
needed_opt: str | None = None,
) -> None: ) -> None:
if self.config.option.tbstyle != "no": if self.config.option.tbstyle != "no":
if not needed_opt or self.hasopt(needed_opt): if not needed_opt or self.hasopt(needed_opt):
reports: List[BaseReport] = self.getreports(which_reports) reports: list[BaseReport] = self.getreports(which_reports)
if not reports: if not reports:
return return
self.write_sep("=", sep_title) self.write_sep("=", sep_title)
@ -1098,7 +1097,7 @@ class TerminalReporter:
def summary_errors(self) -> None: def summary_errors(self) -> None:
if self.config.option.tbstyle != "no": if self.config.option.tbstyle != "no":
reports: List[BaseReport] = self.getreports("error") reports: list[BaseReport] = self.getreports("error")
if not reports: if not reports:
return return
self.write_sep("=", "ERRORS") self.write_sep("=", "ERRORS")
@ -1165,7 +1164,7 @@ class TerminalReporter:
if not self.reportchars: if not self.reportchars:
return return
def show_simple(lines: List[str], *, stat: str) -> None: def show_simple(lines: list[str], *, stat: str) -> None:
failed = self.stats.get(stat, []) failed = self.stats.get(stat, [])
if not failed: if not failed:
return return
@ -1177,7 +1176,7 @@ class TerminalReporter:
) )
lines.append(line) lines.append(line)
def show_xfailed(lines: List[str]) -> None: def show_xfailed(lines: list[str]) -> None:
xfailed = self.stats.get("xfailed", []) xfailed = self.stats.get("xfailed", [])
for rep in xfailed: for rep in xfailed:
verbose_word = rep._get_verbose_word(self.config) verbose_word = rep._get_verbose_word(self.config)
@ -1192,7 +1191,7 @@ class TerminalReporter:
lines.append(line) lines.append(line)
def show_xpassed(lines: List[str]) -> None: def show_xpassed(lines: list[str]) -> None:
xpassed = self.stats.get("xpassed", []) xpassed = self.stats.get("xpassed", [])
for rep in xpassed: for rep in xpassed:
verbose_word = rep._get_verbose_word(self.config) verbose_word = rep._get_verbose_word(self.config)
@ -1206,8 +1205,8 @@ class TerminalReporter:
line += " - " + str(reason) line += " - " + str(reason)
lines.append(line) lines.append(line)
def show_skipped(lines: List[str]) -> None: def show_skipped(lines: list[str]) -> None:
skipped: List[CollectReport] = self.stats.get("skipped", []) skipped: list[CollectReport] = self.stats.get("skipped", [])
fskips = _folded_skips(self.startpath, skipped) if skipped else [] fskips = _folded_skips(self.startpath, skipped) if skipped else []
if not fskips: if not fskips:
return return
@ -1226,7 +1225,7 @@ class TerminalReporter:
else: else:
lines.append("%s [%d] %s: %s" % (markup_word, num, fspath, reason)) lines.append("%s [%d] %s: %s" % (markup_word, num, fspath, reason))
REPORTCHAR_ACTIONS: Mapping[str, Callable[[List[str]], None]] = { REPORTCHAR_ACTIONS: Mapping[str, Callable[[list[str]], None]] = {
"x": show_xfailed, "x": show_xfailed,
"X": show_xpassed, "X": show_xpassed,
"f": partial(show_simple, stat="failed"), "f": partial(show_simple, stat="failed"),
@ -1235,7 +1234,7 @@ class TerminalReporter:
"E": partial(show_simple, stat="error"), "E": partial(show_simple, stat="error"),
} }
lines: List[str] = [] lines: list[str] = []
for char in self.reportchars: for char in self.reportchars:
action = REPORTCHAR_ACTIONS.get(char) action = REPORTCHAR_ACTIONS.get(char)
if action: # skipping e.g. "P" (passed with output) here. if action: # skipping e.g. "P" (passed with output) here.
@ -1246,7 +1245,7 @@ class TerminalReporter:
for line in lines: for line in lines:
self.write_line(line) self.write_line(line)
def _get_main_color(self) -> Tuple[str, List[str]]: def _get_main_color(self) -> tuple[str, list[str]]:
if self._main_color is None or self._known_types is None or self._is_last_item: if self._main_color is None or self._known_types is None or self._is_last_item:
self._set_main_color() self._set_main_color()
assert self._main_color assert self._main_color
@ -1266,7 +1265,7 @@ class TerminalReporter:
return main_color return main_color
def _set_main_color(self) -> None: def _set_main_color(self) -> None:
unknown_types: List[str] = [] unknown_types: list[str] = []
for found_type in self.stats: for found_type in self.stats:
if found_type: # setup/teardown reports have an empty key, ignore them if found_type: # setup/teardown reports have an empty key, ignore them
if found_type not in KNOWN_TYPES and found_type not in unknown_types: if found_type not in KNOWN_TYPES and found_type not in unknown_types:
@ -1274,7 +1273,7 @@ class TerminalReporter:
self._known_types = list(KNOWN_TYPES) + unknown_types self._known_types = list(KNOWN_TYPES) + unknown_types
self._main_color = self._determine_main_color(bool(unknown_types)) self._main_color = self._determine_main_color(bool(unknown_types))
def build_summary_stats_line(self) -> Tuple[List[Tuple[str, Dict[str, bool]]], str]: def build_summary_stats_line(self) -> tuple[list[tuple[str, dict[str, bool]]], str]:
""" """
Build the parts used in the last summary stats line. Build the parts used in the last summary stats line.
@ -1299,14 +1298,14 @@ class TerminalReporter:
else: else:
return self._build_normal_summary_stats_line() return self._build_normal_summary_stats_line()
def _get_reports_to_display(self, key: str) -> List[Any]: def _get_reports_to_display(self, key: str) -> list[Any]:
"""Get test/collection reports for the given status key, such as `passed` or `error`.""" """Get test/collection reports for the given status key, such as `passed` or `error`."""
reports = self.stats.get(key, []) reports = self.stats.get(key, [])
return [x for x in reports if getattr(x, "count_towards_summary", True)] return [x for x in reports if getattr(x, "count_towards_summary", True)]
def _build_normal_summary_stats_line( def _build_normal_summary_stats_line(
self, self,
) -> Tuple[List[Tuple[str, Dict[str, bool]]], str]: ) -> tuple[list[tuple[str, dict[str, bool]]], str]:
main_color, known_types = self._get_main_color() main_color, known_types = self._get_main_color()
parts = [] parts = []
@ -1325,7 +1324,7 @@ class TerminalReporter:
def _build_collect_only_summary_stats_line( def _build_collect_only_summary_stats_line(
self, self,
) -> Tuple[List[Tuple[str, Dict[str, bool]]], str]: ) -> tuple[list[tuple[str, dict[str, bool]]], str]:
deselected = len(self._get_reports_to_display("deselected")) deselected = len(self._get_reports_to_display("deselected"))
errors = len(self._get_reports_to_display("error")) errors = len(self._get_reports_to_display("error"))
@ -1366,7 +1365,7 @@ def _get_node_id_with_markup(tw: TerminalWriter, config: Config, rep: BaseReport
return path return path
def _format_trimmed(format: str, msg: str, available_width: int) -> Optional[str]: def _format_trimmed(format: str, msg: str, available_width: int) -> str | None:
"""Format msg into format, ellipsizing it if doesn't fit in available_width. """Format msg into format, ellipsizing it if doesn't fit in available_width.
Returns None if even the ellipsis can't fit. Returns None if even the ellipsis can't fit.
@ -1392,7 +1391,7 @@ def _format_trimmed(format: str, msg: str, available_width: int) -> Optional[str
def _get_line_with_reprcrash_message( def _get_line_with_reprcrash_message(
config: Config, rep: BaseReport, tw: TerminalWriter, word_markup: Dict[str, bool] config: Config, rep: BaseReport, tw: TerminalWriter, word_markup: dict[str, bool]
) -> str: ) -> str:
"""Get summary line for a report, trying to add reprcrash message.""" """Get summary line for a report, trying to add reprcrash message."""
verbose_word = rep._get_verbose_word(config) verbose_word = rep._get_verbose_word(config)
@ -1422,8 +1421,8 @@ def _get_line_with_reprcrash_message(
def _folded_skips( def _folded_skips(
startpath: Path, startpath: Path,
skipped: Sequence[CollectReport], skipped: Sequence[CollectReport],
) -> List[Tuple[int, str, Optional[int], str]]: ) -> list[tuple[int, str, int | None, str]]:
d: Dict[Tuple[str, Optional[int], str], List[CollectReport]] = {} d: dict[tuple[str, int | None, str], list[CollectReport]] = {}
for event in skipped: for event in skipped:
assert event.longrepr is not None assert event.longrepr is not None
assert isinstance(event.longrepr, tuple), (event, event.longrepr) assert isinstance(event.longrepr, tuple), (event, event.longrepr)
@ -1440,11 +1439,11 @@ def _folded_skips(
and "skip" in keywords and "skip" in keywords
and "pytestmark" not in keywords and "pytestmark" not in keywords
): ):
key: Tuple[str, Optional[int], str] = (fspath, None, reason) key: tuple[str, int | None, str] = (fspath, None, reason)
else: else:
key = (fspath, lineno, reason) key = (fspath, lineno, reason)
d.setdefault(key, []).append(event) d.setdefault(key, []).append(event)
values: List[Tuple[int, str, Optional[int], str]] = [] values: list[tuple[int, str, int | None, str]] = []
for key, events in d.items(): for key, events in d.items():
values.append((len(events), *key)) values.append((len(events), *key))
return values return values
@ -1459,7 +1458,7 @@ _color_for_type = {
_color_for_type_default = "yellow" _color_for_type_default = "yellow"
def pluralize(count: int, noun: str) -> Tuple[int, str]: def pluralize(count: int, noun: str) -> tuple[int, str]:
# No need to pluralize words such as `failed` or `passed`. # No need to pluralize words such as `failed` or `passed`.
if noun not in ["error", "warnings", "test"]: if noun not in ["error", "warnings", "test"]:
return count, noun return count, noun
@ -1472,8 +1471,8 @@ def pluralize(count: int, noun: str) -> Tuple[int, str]:
return count, noun + "s" if count != 1 else noun return count, noun + "s" if count != 1 else noun
def _plugin_nameversions(plugininfo) -> List[str]: def _plugin_nameversions(plugininfo) -> list[str]:
values: List[str] = [] values: list[str] = []
for plugin, dist in plugininfo: for plugin, dist in plugininfo:
# Gets us name and version! # Gets us name and version!
name = f"{dist.project_name}-{dist.version}" name = f"{dist.project_name}-{dist.version}"

View File

@ -1,16 +1,21 @@
from __future__ import annotations
import threading import threading
import traceback import traceback
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
from typing import Optional from typing import TYPE_CHECKING
from typing import Type
import warnings import warnings
import pytest import pytest
if TYPE_CHECKING:
from typing_extensions import Self
# Copied from cpython/Lib/test/support/threading_helper.py, with modifications. # Copied from cpython/Lib/test/support/threading_helper.py, with modifications.
class catch_threading_exception: class catch_threading_exception:
"""Context manager catching threading.Thread exception using """Context manager catching threading.Thread exception using
@ -34,22 +39,22 @@ class catch_threading_exception:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.args: Optional["threading.ExceptHookArgs"] = None self.args: threading.ExceptHookArgs | None = None
self._old_hook: Optional[Callable[["threading.ExceptHookArgs"], Any]] = None self._old_hook: Callable[[threading.ExceptHookArgs], Any] | None = None
def _hook(self, args: "threading.ExceptHookArgs") -> None: def _hook(self, args: threading.ExceptHookArgs) -> None:
self.args = args self.args = args
def __enter__(self) -> "catch_threading_exception": def __enter__(self) -> Self:
self._old_hook = threading.excepthook self._old_hook = threading.excepthook
threading.excepthook = self._hook threading.excepthook = self._hook
return self return self
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
assert self._old_hook is not None assert self._old_hook is not None
threading.excepthook = self._old_hook threading.excepthook = self._old_hook

View File

@ -6,6 +6,8 @@ pytest runtime information (issue #185).
Fixture "mock_timing" also interacts with this module for pytest's own tests. Fixture "mock_timing" also interacts with this module for pytest's own tests.
""" """
from __future__ import annotations
from time import perf_counter from time import perf_counter
from time import sleep from time import sleep
from time import time from time import time

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Support for providing temporary directories to test functions.""" """Support for providing temporary directories to test functions."""
from __future__ import annotations
import dataclasses import dataclasses
import os import os
from pathlib import Path from pathlib import Path
@ -12,8 +14,6 @@ from typing import Dict
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import Literal from typing import Literal
from typing import Optional
from typing import Union
from .pathlib import cleanup_dead_symlinks from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT from .pathlib import LOCK_TIMEOUT
@ -46,20 +46,20 @@ class TempPathFactory:
The base directory can be configured using the ``--basetemp`` option. The base directory can be configured using the ``--basetemp`` option.
""" """
_given_basetemp: Optional[Path] _given_basetemp: Path | None
# pluggy TagTracerSub, not currently exposed, so Any. # pluggy TagTracerSub, not currently exposed, so Any.
_trace: Any _trace: Any
_basetemp: Optional[Path] _basetemp: Path | None
_retention_count: int _retention_count: int
_retention_policy: RetentionType _retention_policy: RetentionType
def __init__( def __init__(
self, self,
given_basetemp: Optional[Path], given_basetemp: Path | None,
retention_count: int, retention_count: int,
retention_policy: RetentionType, retention_policy: RetentionType,
trace, trace,
basetemp: Optional[Path] = None, basetemp: Path | None = None,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> None: ) -> None:
@ -82,7 +82,7 @@ class TempPathFactory:
config: Config, config: Config,
*, *,
_ispytest: bool = False, _ispytest: bool = False,
) -> "TempPathFactory": ) -> TempPathFactory:
"""Create a factory according to pytest configuration. """Create a factory according to pytest configuration.
:meta private: :meta private:
@ -198,7 +198,7 @@ class TempPathFactory:
return basetemp return basetemp
def get_user() -> Optional[str]: def get_user() -> str | None:
"""Return the current user name, or None if getuser() does not work """Return the current user name, or None if getuser() does not work
in the current environment (see #1010).""" in the current environment (see #1010)."""
try: try:
@ -286,7 +286,7 @@ def tmp_path(
del request.node.stash[tmppath_result_key] del request.node.stash[tmppath_result_key]
def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]): def pytest_sessionfinish(session, exitstatus: int | ExitCode):
"""After each session, remove base directory if all the tests passed, """After each session, remove base directory if all the tests passed,
the policy is "failed", and the basetemp is not specified by a user. the policy is "failed", and the basetemp is not specified by a user.
""" """
@ -317,6 +317,6 @@ def pytest_runtest_makereport(
) -> Generator[None, TestReport, TestReport]: ) -> Generator[None, TestReport, TestReport]:
rep = yield rep = yield
assert rep.when is not None assert rep.when is not None
empty: Dict[str, bool] = {} empty: dict[str, bool] = {}
item.stash.setdefault(tmppath_result_key, empty)[rep.when] = rep.passed item.stash.setdefault(tmppath_result_key, empty)[rep.when] = rep.passed
return rep return rep

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Discover and run std-library "unittest" style tests.""" """Discover and run std-library "unittest" style tests."""
from __future__ import annotations
import sys import sys
import traceback import traceback
import types import types
@ -8,8 +10,6 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
from typing import Iterable from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Type from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -47,8 +47,8 @@ if TYPE_CHECKING:
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: Union[Module, Class], name: str, obj: object collector: Module | Class, name: str, obj: object
) -> Optional["UnitTestCase"]: ) -> UnitTestCase | None:
# Has unittest been imported and is obj a subclass of its TestCase? # Has unittest been imported and is obj a subclass of its TestCase?
try: try:
ut = sys.modules["unittest"] ut = sys.modules["unittest"]
@ -74,7 +74,7 @@ class UnitTestCase(Class):
# it. # it.
return self.obj("runTest") return self.obj("runTest")
def collect(self) -> Iterable[Union[Item, Collector]]: def collect(self) -> Iterable[Item | Collector]:
from unittest import TestLoader from unittest import TestLoader
cls = self.obj cls = self.obj
@ -194,7 +194,7 @@ class UnitTestCase(Class):
class TestCaseFunction(Function): class TestCaseFunction(Function):
nofuncargs = True nofuncargs = True
_excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None _excinfo: list[_pytest._code.ExceptionInfo[BaseException]] | None = None
def _getinstance(self): def _getinstance(self):
assert isinstance(self.parent, UnitTestCase) assert isinstance(self.parent, UnitTestCase)
@ -208,7 +208,7 @@ class TestCaseFunction(Function):
def setup(self) -> None: def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()'). # A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Optional[Callable[[], None]] = None self._explicit_tearDown: Callable[[], None] | None = None
super().setup() super().setup()
def teardown(self) -> None: def teardown(self) -> None:
@ -219,10 +219,10 @@ class TestCaseFunction(Function):
del self._instance del self._instance
super().teardown() super().teardown()
def startTest(self, testcase: "unittest.TestCase") -> None: def startTest(self, testcase: unittest.TestCase) -> None:
pass pass
def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None: def _addexcinfo(self, rawexcinfo: _SysExcInfoType) -> None:
# Unwrap potential exception info (see twisted trial support below). # Unwrap potential exception info (see twisted trial support below).
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo) rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo)
try: try:
@ -258,7 +258,7 @@ class TestCaseFunction(Function):
self.__dict__.setdefault("_excinfo", []).append(excinfo) self.__dict__.setdefault("_excinfo", []).append(excinfo)
def addError( def addError(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType" self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType
) -> None: ) -> None:
try: try:
if isinstance(rawexcinfo[1], exit.Exception): if isinstance(rawexcinfo[1], exit.Exception):
@ -268,11 +268,11 @@ class TestCaseFunction(Function):
self._addexcinfo(rawexcinfo) self._addexcinfo(rawexcinfo)
def addFailure( def addFailure(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType" self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType
) -> None: ) -> None:
self._addexcinfo(rawexcinfo) self._addexcinfo(rawexcinfo)
def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None: def addSkip(self, testcase: unittest.TestCase, reason: str) -> None:
try: try:
raise pytest.skip.Exception(reason, _use_item_location=True) raise pytest.skip.Exception(reason, _use_item_location=True)
except skip.Exception: except skip.Exception:
@ -280,8 +280,8 @@ class TestCaseFunction(Function):
def addExpectedFailure( def addExpectedFailure(
self, self,
testcase: "unittest.TestCase", testcase: unittest.TestCase,
rawexcinfo: "_SysExcInfoType", rawexcinfo: _SysExcInfoType,
reason: str = "", reason: str = "",
) -> None: ) -> None:
try: try:
@ -291,8 +291,8 @@ class TestCaseFunction(Function):
def addUnexpectedSuccess( def addUnexpectedSuccess(
self, self,
testcase: "unittest.TestCase", testcase: unittest.TestCase,
reason: Optional["twisted.trial.unittest.Todo"] = None, reason: twisted.trial.unittest.Todo | None = None,
) -> None: ) -> None:
msg = "Unexpected success" msg = "Unexpected success"
if reason: if reason:
@ -303,13 +303,13 @@ class TestCaseFunction(Function):
except fail.Exception: except fail.Exception:
self._addexcinfo(sys.exc_info()) self._addexcinfo(sys.exc_info())
def addSuccess(self, testcase: "unittest.TestCase") -> None: def addSuccess(self, testcase: unittest.TestCase) -> None:
pass pass
def stopTest(self, testcase: "unittest.TestCase") -> None: def stopTest(self, testcase: unittest.TestCase) -> None:
pass pass
def addDuration(self, testcase: "unittest.TestCase", elapsed: float) -> None: def addDuration(self, testcase: unittest.TestCase, elapsed: float) -> None:
pass pass
def runtest(self) -> None: def runtest(self) -> None:

View File

@ -1,16 +1,21 @@
from __future__ import annotations
import sys import sys
import traceback import traceback
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator from typing import Generator
from typing import Optional from typing import TYPE_CHECKING
from typing import Type
import warnings import warnings
import pytest import pytest
if TYPE_CHECKING:
from typing_extensions import Self
# Copied from cpython/Lib/test/support/__init__.py, with modifications. # Copied from cpython/Lib/test/support/__init__.py, with modifications.
class catch_unraisable_exception: class catch_unraisable_exception:
"""Context manager catching unraisable exception using sys.unraisablehook. """Context manager catching unraisable exception using sys.unraisablehook.
@ -34,24 +39,24 @@ class catch_unraisable_exception:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.unraisable: Optional["sys.UnraisableHookArgs"] = None self.unraisable: sys.UnraisableHookArgs | None = None
self._old_hook: Optional[Callable[["sys.UnraisableHookArgs"], Any]] = None self._old_hook: Callable[[sys.UnraisableHookArgs], Any] | None = None
def _hook(self, unraisable: "sys.UnraisableHookArgs") -> None: def _hook(self, unraisable: sys.UnraisableHookArgs) -> None:
# Storing unraisable.object can resurrect an object which is being # Storing unraisable.object can resurrect an object which is being
# finalized. Storing unraisable.exc_value creates a reference cycle. # finalized. Storing unraisable.exc_value creates a reference cycle.
self.unraisable = unraisable self.unraisable = unraisable
def __enter__(self) -> "catch_unraisable_exception": def __enter__(self) -> Self:
self._old_hook = sys.unraisablehook self._old_hook = sys.unraisablehook
sys.unraisablehook = self._hook sys.unraisablehook = self._hook
return self return self
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: type[BaseException] | None,
exc_val: Optional[BaseException], exc_val: BaseException | None,
exc_tb: Optional[TracebackType], exc_tb: TracebackType | None,
) -> None: ) -> None:
assert self._old_hook is not None assert self._old_hook is not None
sys.unraisablehook = self._old_hook sys.unraisablehook = self._old_hook

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
from types import FunctionType from types import FunctionType
from typing import Any from typing import Any
from typing import final from typing import final
from typing import Generic from typing import Generic
from typing import Type
from typing import TypeVar from typing import TypeVar
import warnings import warnings
@ -72,7 +73,7 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
__module__ = "pytest" __module__ = "pytest"
@classmethod @classmethod
def simple(cls, apiname: str) -> "PytestExperimentalApiWarning": def simple(cls, apiname: str) -> PytestExperimentalApiWarning:
return cls(f"{apiname} is an experimental api that may change over time") return cls(f"{apiname} is an experimental api that may change over time")
@ -132,7 +133,7 @@ class UnformattedWarning(Generic[_W]):
as opposed to a direct message. as opposed to a direct message.
""" """
category: Type["_W"] category: type[_W]
template: str template: str
def format(self, **kwargs: Any) -> _W: def format(self, **kwargs: Any) -> _W:

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
import sys import sys
from typing import Generator from typing import Generator
from typing import Literal from typing import Literal
from typing import Optional
import warnings import warnings
from _pytest.config import apply_warning_filters from _pytest.config import apply_warning_filters
@ -28,7 +29,7 @@ def catch_warnings_for_item(
config: Config, config: Config,
ihook, ihook,
when: Literal["config", "collect", "runtest"], when: Literal["config", "collect", "runtest"],
item: Optional[Item], item: Item | None,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
"""Context manager that catches warnings generated in the contained execution block. """Context manager that catches warnings generated in the contained execution block.
@ -142,7 +143,7 @@ def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
@pytest.hookimpl(wrapper=True) @pytest.hookimpl(wrapper=True)
def pytest_load_initial_conftests( def pytest_load_initial_conftests(
early_config: "Config", early_config: Config,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
with catch_warnings_for_item( with catch_warnings_for_item(
config=early_config, ihook=early_config.hook, when="config", item=None config=early_config, ihook=early_config.hook, when="config", item=None

View File

@ -1,6 +1,8 @@
# shim for pylib going away # shim for pylib going away
# if pylib is installed this file will get skipped # if pylib is installed this file will get skipped
# (`py/__init__.py` has higher precedence) # (`py/__init__.py` has higher precedence)
from __future__ import annotations
import sys import sys
import _pytest._py.error as error import _pytest._py.error as error

View File

@ -1,6 +1,8 @@
# PYTHON_ARGCOMPLETE_OK # PYTHON_ARGCOMPLETE_OK
"""pytest: unit and functional testing with Python.""" """pytest: unit and functional testing with Python."""
from __future__ import annotations
from _pytest import __version__ from _pytest import __version__
from _pytest import version_tuple from _pytest import version_tuple
from _pytest._code import ExceptionInfo from _pytest._code import ExceptionInfo

Some files were not shown because too many files have changed in this diff Show More