Improve typing

Fix #1057
This commit is contained in:
Ran Benita 2023-04-18 23:06:34 +03:00
parent 5dfc590a8c
commit b78cf1c0ce
24 changed files with 858 additions and 477 deletions

View File

@ -32,3 +32,4 @@ repos:
- pytest>=7.0.0
- execnet>=2.1.0
- types-psutil
- setproctitle

1
changelog/1057.trivial Normal file
View File

@ -0,0 +1 @@
The internals of pytest-xdist are now fully typed. The typing is not exposed yet.

View File

@ -137,19 +137,11 @@ lines-after-imports = 2
[tool.mypy]
mypy_path = ["src"]
files = ["src", "testing"]
# TODO: Enable this & fix errors.
# check_untyped_defs = true
disallow_any_generics = true
ignore_missing_imports = true
no_implicit_optional = true
show_error_codes = true
strict_equality = true
warn_redundant_casts = true
warn_return_any = true
strict = true
warn_unreachable = true
warn_unused_configs = true
# TODO: Enable this & fix errors.
# no_implicit_reexport = true
[[tool.mypy.overrides]]
module = ["xdist._version"]
ignore_missing_imports = true
[tool.towncrier]

View File

@ -5,11 +5,15 @@ from enum import Enum
from queue import Empty
from queue import Queue
import sys
from typing import Any
from typing import Sequence
import warnings
import execnet
import pytest
from xdist.remote import Producer
from xdist.remote import WorkerInfo
from xdist.scheduler import EachScheduling
from xdist.scheduler import LoadFileScheduling
from xdist.scheduler import LoadGroupScheduling
@ -18,6 +22,7 @@ from xdist.scheduler import LoadScopeScheduling
from xdist.scheduler import Scheduling
from xdist.scheduler import WorkStealingScheduling
from xdist.workermanage import NodeManager
from xdist.workermanage import WorkerController
class Interrupted(KeyboardInterrupt):
@ -38,29 +43,31 @@ class DSession:
it will wait for instructions.
"""
def __init__(self, config):
shouldstop: bool | str
def __init__(self, config: pytest.Config) -> None:
self.config = config
self.log = Producer("dsession", enabled=config.option.debug)
self.nodemanager = None
self.sched = None
self.nodemanager: NodeManager | None = None
self.sched: Scheduling | None = None
self.shuttingdown = False
self.countfailures = 0
self.maxfail = config.getvalue("maxfail")
self.queue = Queue()
self._session = None
self._failed_collection_errors = {}
self._active_nodes = set()
self.maxfail: int = config.getvalue("maxfail")
self.queue: Queue[tuple[str, dict[str, Any]]] = Queue()
self._session: pytest.Session | None = None
self._failed_collection_errors: dict[object, bool] = {}
self._active_nodes: set[WorkerController] = set()
self._failed_nodes_count = 0
self._max_worker_restart = get_default_max_worker_restart(self.config)
# summary message to print at the end of the session
self._summary_report = None
self._summary_report: str | None = None
self.terminal = config.pluginmanager.getplugin("terminalreporter")
if self.terminal:
self.trdist = TerminalDistReporter(config)
config.pluginmanager.register(self.trdist, "terminaldistreporter")
@property
def session_finished(self):
def session_finished(self) -> bool:
"""Return True if the distributed session has finished.
This means all nodes have executed all test items. This is
@ -68,12 +75,12 @@ class DSession:
"""
return bool(self.shuttingdown and not self._active_nodes)
def report_line(self, line):
def report_line(self, line: str) -> None:
if self.terminal and self.config.option.verbose >= 0:
self.terminal.write_line(line)
@pytest.hookimpl(trylast=True)
def pytest_sessionstart(self, session):
def pytest_sessionstart(self, session: pytest.Session) -> None:
"""Creates and starts the nodes.
The nodes are setup to put their events onto self.queue. As
@ -85,7 +92,7 @@ class DSession:
self._session = session
@pytest.hookimpl
def pytest_sessionfinish(self, session):
def pytest_sessionfinish(self) -> None:
"""Shutdown all nodes."""
nm = getattr(self, "nodemanager", None) # if not fully initialized
if nm is not None:
@ -93,12 +100,16 @@ class DSession:
self._session = None
@pytest.hookimpl
def pytest_collection(self):
def pytest_collection(self) -> bool:
# prohibit collection of test items in controller process
return True
@pytest.hookimpl(trylast=True)
def pytest_xdist_make_scheduler(self, config, log) -> Scheduling | None:
def pytest_xdist_make_scheduler(
self,
config: pytest.Config,
log: Producer,
) -> Scheduling | None:
dist = config.getvalue("dist")
if dist == "each":
return EachScheduling(config, log)
@ -115,7 +126,7 @@ class DSession:
return None
@pytest.hookimpl
def pytest_runtestloop(self):
def pytest_runtestloop(self) -> bool:
self.sched = self.config.hook.pytest_xdist_make_scheduler(
config=self.config, log=self.log
)
@ -132,7 +143,7 @@ class DSession:
raise pending_exception
return True
def loop_once(self):
def loop_once(self) -> None:
"""Process one callback from one of the workers."""
while 1:
if not self._active_nodes:
@ -150,6 +161,7 @@ class DSession:
call = getattr(self, method)
self.log("calling method", method, kwargs)
call(**kwargs)
assert self.sched is not None
if self.sched.tests_finished:
self.triggershutdown()
@ -157,7 +169,11 @@ class DSession:
# callbacks for processing events from workers
#
def worker_workerready(self, node, workerinfo):
def worker_workerready(
self,
node: WorkerController,
workerinfo: WorkerInfo,
) -> None:
"""Emitted when a node first starts up.
This adds the node to the scheduler, nodes continue with
@ -171,9 +187,10 @@ class DSession:
if self.shuttingdown:
node.shutdown()
else:
assert self.sched is not None
self.sched.add_node(node)
def worker_workerfinished(self, node):
def worker_workerfinished(self, node: WorkerController) -> None:
"""Emitted when node executes its pytest_sessionfinish hook.
Removes the node from the scheduler.
@ -194,12 +211,15 @@ class DSession:
self.shouldstop = shouldx
break
else:
assert self.sched is not None
if node in self.sched.nodes:
crashitem = self.sched.remove_node(node)
assert not crashitem, (crashitem, node)
self._active_nodes.remove(node)
def worker_internal_error(self, node, formatted_error):
def worker_internal_error(
self, node: WorkerController, formatted_error: str
) -> None:
"""
pytest_internalerror() was called on the worker.
@ -215,9 +235,10 @@ class DSession:
excrepr = excinfo.getrepr()
self.config.hook.pytest_internalerror(excrepr=excrepr, excinfo=excinfo)
def worker_errordown(self, node, error):
def worker_errordown(self, node: WorkerController, error: object | None) -> None:
"""Emitted by the WorkerController when a node dies."""
self.config.hook.pytest_testnodedown(node=node, error=error)
assert self.sched is not None
try:
crashitem = self.sched.remove_node(node)
except KeyError:
@ -235,7 +256,7 @@ class DSession:
if self._max_worker_restart == 0:
msg = f"worker {node.gateway.id} crashed and worker restarting disabled"
else:
msg = "maximum crashed workers reached: %d" % self._max_worker_restart
msg = f"maximum crashed workers reached: {self._max_worker_restart}"
self._summary_report = msg
self.report_line("\n" + msg)
self.triggershutdown()
@ -246,11 +267,13 @@ class DSession:
self._active_nodes.remove(node)
@pytest.hookimpl
def pytest_terminal_summary(self, terminalreporter):
def pytest_terminal_summary(self, terminalreporter: Any) -> None:
if self.config.option.verbose >= 0 and self._summary_report:
terminalreporter.write_sep("=", f"xdist: {self._summary_report}")
def worker_collectionfinish(self, node, ids):
def worker_collectionfinish(
self, node: WorkerController, ids: Sequence[str]
) -> None:
"""Worker has finished test collection.
This adds the collection for this node to the scheduler. If
@ -264,7 +287,9 @@ class DSession:
self.config.hook.pytest_xdist_node_collection_finished(node=node, ids=ids)
# tell session which items were effectively collected otherwise
# the controller node will finish the session with EXIT_NOTESTSCOLLECTED
assert self._session is not None
self._session.testscollected = len(ids)
assert self.sched is not None
self.sched.add_node_collection(node, ids)
if self.terminal:
self.trdist.setstatus(
@ -280,29 +305,44 @@ class DSession:
)
self.sched.schedule()
def worker_logstart(self, node, nodeid, location):
def worker_logstart(
self,
node: WorkerController,
nodeid: str,
location: tuple[str, int | None, str],
) -> None:
"""Emitted when a node calls the pytest_runtest_logstart hook."""
self.config.hook.pytest_runtest_logstart(nodeid=nodeid, location=location)
def worker_logfinish(self, node, nodeid, location):
def worker_logfinish(
self,
node: WorkerController,
nodeid: str,
location: tuple[str, int | None, str],
) -> None:
"""Emitted when a node calls the pytest_runtest_logfinish hook."""
self.config.hook.pytest_runtest_logfinish(nodeid=nodeid, location=location)
def worker_testreport(self, node, rep):
def worker_testreport(self, node: WorkerController, rep: pytest.TestReport) -> None:
"""Emitted when a node calls the pytest_runtest_logreport hook."""
rep.node = node
rep.node = node # type: ignore[attr-defined]
self.config.hook.pytest_runtest_logreport(report=rep)
self._handlefailures(rep)
def worker_runtest_protocol_complete(self, node, item_index, duration):
def worker_runtest_protocol_complete(
self, node: WorkerController, item_index: int, duration: float
) -> None:
"""
Emitted when a node fires the 'runtest_protocol_complete' event,
signalling that a test has completed the runtestprotocol and should be
removed from the pending list in the scheduler.
"""
assert self.sched is not None
self.sched.mark_test_complete(node, item_index, duration)
def worker_unscheduled(self, node, indices):
def worker_unscheduled(
self, node: WorkerController, indices: Sequence[int]
) -> None:
"""
Emitted when a node fires the 'unscheduled' event, signalling that
some tests have been removed from the worker's queue and should be
@ -311,9 +351,14 @@ class DSession:
This should happen only in response to 'steal' command, so schedulers
not using 'steal' command don't have to implement it.
"""
assert self.sched is not None
self.sched.remove_pending_tests_from_node(node, indices)
def worker_collectreport(self, node, rep):
def worker_collectreport(
self,
node: WorkerController,
rep: pytest.CollectReport | pytest.TestReport,
) -> None:
"""Emitted when a node calls the pytest_collectreport hook.
Because we only need the report when there's a failure/skip, as optimization
@ -322,14 +367,20 @@ class DSession:
assert not rep.passed
self._failed_worker_collectreport(node, rep)
def worker_warning_recorded(self, warning_message, when, nodeid, location):
def worker_warning_recorded(
self,
warning_message: warnings.WarningMessage,
when: str,
nodeid: str,
location: tuple[str, int, str] | None,
) -> None:
"""Emitted when a node calls the pytest_warning_recorded hook."""
kwargs = dict(
warning_message=warning_message, when=when, nodeid=nodeid, location=location
)
self.config.hook.pytest_warning_recorded.call_historic(kwargs=kwargs)
def _clone_node(self, node):
def _clone_node(self, node: WorkerController) -> WorkerController:
"""Return new node based on an existing one.
This is normally for when a node dies, this will copy the spec
@ -339,12 +390,17 @@ class DSession:
"""
spec = node.gateway.spec
spec.id = None
assert self.nodemanager is not None
self.nodemanager.group.allocate_id(spec)
node = self.nodemanager.setup_node(spec, self.queue.put)
self._active_nodes.add(node)
return node
clone = self.nodemanager.setup_node(spec, self.queue.put)
self._active_nodes.add(clone)
return clone
def _failed_worker_collectreport(self, node, rep):
def _failed_worker_collectreport(
self,
node: WorkerController,
rep: pytest.CollectReport | pytest.TestReport,
) -> None:
# Check we haven't already seen this report (from
# another worker).
if rep.longrepr not in self._failed_collection_errors:
@ -352,7 +408,10 @@ class DSession:
self.config.hook.pytest_collectreport(report=rep)
self._handlefailures(rep)
def _handlefailures(self, rep):
def _handlefailures(
self,
rep: pytest.CollectReport | pytest.TestReport,
) -> None:
if rep.failed:
self.countfailures += 1
if (
@ -362,22 +421,28 @@ class DSession:
):
self.shouldstop = f"stopping after {self.countfailures} failures"
def triggershutdown(self):
def triggershutdown(self) -> None:
if not self.shuttingdown:
self.log("triggering shutdown")
self.shuttingdown = True
assert self.sched is not None
for node in self.sched.nodes:
node.shutdown()
def handle_crashitem(self, nodeid, worker):
def handle_crashitem(self, nodeid: str, worker: WorkerController) -> None:
# XXX get more reporting info by recording pytest_runtest_logstart?
# XXX count no of failures and retry N times
fspath = nodeid.split("::")[0]
msg = f"worker {worker.gateway.id!r} crashed while running {nodeid!r}"
rep = pytest.TestReport(
nodeid, (fspath, None, fspath), (), "failed", msg, "???"
nodeid=nodeid,
location=(fspath, None, fspath),
keywords={},
outcome="failed",
longrepr=msg,
when="???", # type: ignore[arg-type]
)
rep.node = worker
rep.node = worker # type: ignore[attr-defined]
self.config.hook.pytest_handlecrashitem(
crashitem=nodeid,
@ -404,10 +469,10 @@ class WorkerStatus(Enum):
class TerminalDistReporter:
def __init__(self, config) -> None:
def __init__(self, config: pytest.Config) -> None:
self.config = config
self.tr = config.pluginmanager.getplugin("terminalreporter")
self._status: dict[str, tuple[WorkerStatus, int]] = {}
self._status: dict[object, tuple[WorkerStatus, int]] = {}
self._lastlen = 0
self._isatty = getattr(self.tr, "isatty", self.tr.hasmarkup)
@ -419,7 +484,12 @@ class TerminalDistReporter:
self.write_line(self.getstatus())
def setstatus(
self, spec, status: WorkerStatus, *, tests_collected: int, show: bool = True
self,
spec: execnet.XSpec,
status: WorkerStatus,
*,
tests_collected: int,
show: bool = True,
) -> None:
self._status[spec.id] = (status, tests_collected)
if show and self._isatty:
@ -433,7 +503,7 @@ class TerminalDistReporter:
return "bringing up nodes..."
def rewrite(self, line, newline=False):
def rewrite(self, line: str, newline: bool = False) -> None:
pline = line + " " * max(self._lastlen - len(line), 0)
if newline:
self._lastlen = 0
@ -443,7 +513,7 @@ class TerminalDistReporter:
self.tr.rewrite(pline, bold=True)
@pytest.hookimpl
def pytest_xdist_setupnodes(self, specs) -> None:
def pytest_xdist_setupnodes(self, specs: Sequence[execnet.XSpec]) -> None:
self._specs = specs
for spec in specs:
self.setstatus(spec, WorkerStatus.Created, tests_collected=0, show=False)
@ -451,7 +521,7 @@ class TerminalDistReporter:
self.ensure_show_status()
@pytest.hookimpl
def pytest_xdist_newgateway(self, gateway) -> None:
def pytest_xdist_newgateway(self, gateway: execnet.Gateway) -> None:
if self.config.option.verbose > 0:
rinfo = gateway._rinfo()
different_interpreter = rinfo.executable != sys.executable
@ -464,7 +534,7 @@ class TerminalDistReporter:
self.setstatus(gateway.spec, WorkerStatus.Initialized, tests_collected=0)
@pytest.hookimpl
def pytest_testnodeready(self, node) -> None:
def pytest_testnodeready(self, node: WorkerController) -> None:
if self.config.option.verbose > 0:
d = node.workerinfo
different_interpreter = d.get("executable") != sys.executable
@ -476,23 +546,25 @@ class TerminalDistReporter:
)
@pytest.hookimpl
def pytest_testnodedown(self, node, error) -> None:
def pytest_testnodedown(self, node: WorkerController, error: object) -> None:
if not error:
return
self.write_line(f"[{node.gateway.id}] node down: {error}")
def get_default_max_worker_restart(config):
def get_default_max_worker_restart(config: pytest.Config) -> int | None:
"""Gets the default value of --max-worker-restart option if it is not provided.
Use a reasonable default to avoid workers from restarting endlessly due to crashing collections (#226).
"""
result = config.option.maxworkerrestart
if result is not None:
result = int(result)
result_str: str | None = config.option.maxworkerrestart
if result_str is not None:
result = int(result_str)
elif config.option.numprocesses:
# if --max-worker-restart was not provided, use a reasonable default (#226)
result = config.option.numprocesses * 4
else:
result = None
return result

View File

@ -13,6 +13,7 @@ import os
from pathlib import Path
import sys
import time
from typing import Any
from typing import Sequence
from _pytest._io import TerminalWriter
@ -23,7 +24,7 @@ from xdist._path import visit_path
@pytest.hookimpl
def pytest_addoption(parser):
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("xdist", "distributed and subprocess testing")
group._addoption(
"-f",
@ -37,13 +38,14 @@ def pytest_addoption(parser):
@pytest.hookimpl
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: pytest.Config) -> int | None:
if config.getoption("looponfail"):
usepdb = config.getoption("usepdb", False) # a core option
if usepdb:
raise pytest.UsageError("--pdb is incompatible with --looponfail.")
looponfail_main(config)
return 2 # looponfail only can get stop with ctrl-C anyway
return None
def looponfail_main(config: pytest.Config) -> None:
@ -68,19 +70,21 @@ def looponfail_main(config: pytest.Config) -> None:
class RemoteControl:
def __init__(self, config):
self.config = config
self.failures = []
gateway: execnet.Gateway
def trace(self, *args):
def __init__(self, config: pytest.Config) -> None:
self.config = config
self.failures: list[str] = []
def trace(self, *args: object) -> None:
if self.config.option.debug:
msg = " ".join(str(x) for x in args)
print("RemoteControl:", msg)
def initgateway(self):
def initgateway(self) -> execnet.Gateway:
return execnet.makegateway("popen")
def setup(self):
def setup(self) -> None:
if hasattr(self, "gateway"):
raise ValueError("already have gateway %r" % self.gateway)
self.trace("setting up worker session")
@ -90,17 +94,17 @@ class RemoteControl:
args=self.config.args,
option_dict=vars(self.config.option),
)
remote_outchannel = channel.receive()
remote_outchannel: execnet.Channel = channel.receive()
out = TerminalWriter()
def write(s):
def write(s: str) -> None:
out._file.write(s)
out._file.flush()
remote_outchannel.setcallback(write)
def ensure_teardown(self):
def ensure_teardown(self) -> None:
if hasattr(self, "channel"):
if not self.channel.isclosed():
self.trace("closing", self.channel)
@ -111,12 +115,12 @@ class RemoteControl:
self.gateway.exit()
del self.gateway
def runsession(self):
def runsession(self) -> tuple[list[str], list[str], bool]:
try:
self.trace("sending", self.failures)
self.channel.send(self.failures)
try:
return self.channel.receive()
return self.channel.receive() # type: ignore[no-any-return]
except self.channel.RemoteError:
e = sys.exc_info()[1]
self.trace("ERROR", e)
@ -124,7 +128,7 @@ class RemoteControl:
finally:
self.ensure_teardown()
def loop_once(self):
def loop_once(self) -> None:
self.setup()
self.wasfailing = self.failures and len(self.failures)
result = self.runsession()
@ -139,7 +143,9 @@ class RemoteControl:
self.failures = uniq_failures
def repr_pytest_looponfailinfo(failreports, rootdirs):
def repr_pytest_looponfailinfo(
failreports: Sequence[str], rootdirs: Sequence[Path]
) -> None:
tr = TerminalWriter()
if failreports:
tr.sep("#", "LOOPONFAILING", bold=True)
@ -151,12 +157,16 @@ def repr_pytest_looponfailinfo(failreports, rootdirs):
tr.line(f"### Watching: {rootdir}", bold=True)
def init_worker_session(channel, args, option_dict):
def init_worker_session(
channel: "execnet.Channel", # noqa: UP037
args: list[str],
option_dict: dict[str, "Any"], # noqa: UP037
) -> None:
import os
import sys
outchannel = channel.gateway.newchannel()
sys.stdout = sys.stderr = outchannel.makefile("w")
sys.stdout = sys.stderr = outchannel.makefile("w") # type: ignore[assignment]
channel.send(outchannel)
# prune sys.path to not contain relative paths
newpaths = []
@ -179,21 +189,21 @@ def init_worker_session(channel, args, option_dict):
class WorkerFailSession:
def __init__(self, config, channel):
def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
self.config = config
self.channel = channel
self.recorded_failures = []
self.recorded_failures: list[pytest.CollectReport | pytest.TestReport] = []
self.collection_failed = False
config.pluginmanager.register(self)
config.option.looponfail = False
config.option.usepdb = False
def DEBUG(self, *args):
def DEBUG(self, *args: object) -> None:
if self.config.option.debug:
print(" ".join(map(str, args)))
@pytest.hookimpl
def pytest_collection(self, session):
def pytest_collection(self, session: pytest.Session) -> bool:
self.session = session
self.trails = self.current_command
hook = self.session.ihook
@ -208,17 +218,17 @@ class WorkerFailSession:
return True
@pytest.hookimpl
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
if report.failed:
self.recorded_failures.append(report)
@pytest.hookimpl
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: pytest.CollectReport) -> None:
if report.failed:
self.recorded_failures.append(report)
self.collection_failed = True
def main(self):
def main(self) -> None:
self.DEBUG("WORKER: received configuration, waiting for command trails")
try:
command = self.channel.receive()
@ -233,7 +243,8 @@ class WorkerFailSession:
loc = rep.longrepr
loc = str(getattr(loc, "reprcrash", loc))
failreports.append(loc)
self.channel.send((trails, failreports, self.collection_failed))
result = (trails, failreports, self.collection_failed)
self.channel.send(result)
class StatRecorder:
@ -248,7 +259,7 @@ class StatRecorder:
def rec(self, p: Path) -> bool:
return not p.name.startswith(".") and p.exists()
def waitonchange(self, checkinterval=1.0):
def waitonchange(self, checkinterval: float = 1.0) -> None:
while 1:
changed = self.check()
if changed:

View File

@ -12,16 +12,32 @@ must be taken in plugins in case ``xdist`` is not installed. Please see:
http://pytest.org/en/latest/writing_plugins.html#optionally-using-hooks-from-3rd-party-plugins
"""
from __future__ import annotations
import os
from typing import Any
from typing import Sequence
from typing import TYPE_CHECKING
import execnet
import pytest
if TYPE_CHECKING:
from xdist.remote import Producer
from xdist.scheduler.protocol import Scheduling
from xdist.workermanage import WorkerController
@pytest.hookspec()
def pytest_xdist_setupnodes(config, specs):
def pytest_xdist_setupnodes(
config: pytest.Config, specs: Sequence[execnet.XSpec]
) -> None:
"""Called before any remote node is set up."""
@pytest.hookspec()
def pytest_xdist_newgateway(gateway):
def pytest_xdist_newgateway(gateway: execnet.Gateway) -> None:
"""Called on new raw gateway creation."""
@ -30,7 +46,10 @@ def pytest_xdist_newgateway(gateway):
"rsync feature is deprecated and will be removed in pytest-xdist 4.0"
)
)
def pytest_xdist_rsyncstart(source, gateways):
def pytest_xdist_rsyncstart(
source: str | os.PathLike[str],
gateways: Sequence[execnet.Gateway],
) -> None:
"""Called before rsyncing a directory to remote gateways takes place."""
@ -39,52 +58,62 @@ def pytest_xdist_rsyncstart(source, gateways):
"rsync feature is deprecated and will be removed in pytest-xdist 4.0"
)
)
def pytest_xdist_rsyncfinish(source, gateways):
def pytest_xdist_rsyncfinish(
source: str | os.PathLike[str],
gateways: Sequence[execnet.Gateway],
) -> None:
"""Called after rsyncing a directory to remote gateways takes place."""
@pytest.hookspec(firstresult=True)
def pytest_xdist_getremotemodule():
def pytest_xdist_getremotemodule() -> Any:
"""Called when creating remote node."""
@pytest.hookspec()
def pytest_configure_node(node):
def pytest_configure_node(node: WorkerController) -> None:
"""Configure node information before it gets instantiated."""
@pytest.hookspec()
def pytest_testnodeready(node):
def pytest_testnodeready(node: WorkerController) -> None:
"""Test Node is ready to operate."""
@pytest.hookspec()
def pytest_testnodedown(node, error):
def pytest_testnodedown(node: WorkerController, error: object | None) -> None:
"""Test Node is down."""
@pytest.hookspec()
def pytest_xdist_node_collection_finished(node, ids):
def pytest_xdist_node_collection_finished(
node: WorkerController, ids: Sequence[str]
) -> None:
"""Called by the controller node when a worker node finishes collecting."""
@pytest.hookspec(firstresult=True)
def pytest_xdist_make_scheduler(config, log):
def pytest_xdist_make_scheduler(
config: pytest.Config, log: Producer
) -> Scheduling | None:
"""Return a node scheduler implementation."""
@pytest.hookspec(firstresult=True)
def pytest_xdist_auto_num_workers(config):
def pytest_xdist_auto_num_workers(config: pytest.Config) -> int:
"""
Return the number of workers to spawn when ``--numprocesses=auto`` is given in the
command-line.
.. versionadded:: 2.1
"""
raise NotImplementedError()
@pytest.hookspec(firstresult=True)
def pytest_handlecrashitem(crashitem, report, sched):
def pytest_handlecrashitem(
crashitem: str, report: pytest.TestReport, sched: Scheduling
) -> None:
"""
Handle a crashitem, modifying the report if necessary.

View File

@ -1,5 +1,8 @@
from __future__ import annotations
import os
import sys
from typing import Literal
import uuid
import warnings
@ -10,7 +13,7 @@ _sys_path = list(sys.path) # freeze a copy of sys.path at interpreter startup
@pytest.hookimpl
def pytest_xdist_auto_num_workers(config):
def pytest_xdist_auto_num_workers(config: pytest.Config) -> int:
env_var = os.environ.get("PYTEST_XDIST_AUTO_NUM_WORKERS")
if env_var:
try:
@ -25,14 +28,14 @@ def pytest_xdist_auto_num_workers(config):
except ImportError:
pass
else:
use_logical = config.option.numprocesses == "logical"
use_logical: bool = config.option.numprocesses == "logical"
count = psutil.cpu_count(logical=use_logical) or psutil.cpu_count()
if count:
return count
try:
from os import sched_getaffinity
def cpu_count():
def cpu_count() -> int:
return len(sched_getaffinity(0))
except ImportError:
@ -40,7 +43,7 @@ def pytest_xdist_auto_num_workers(config):
# workaround https://bitbucket.org/pypy/pypy/issues/2375
return 2
try:
from os import cpu_count
from os import cpu_count # type: ignore[assignment]
except ImportError:
from multiprocessing import cpu_count
try:
@ -50,15 +53,15 @@ def pytest_xdist_auto_num_workers(config):
return n if n else 1
def parse_numprocesses(s):
def parse_numprocesses(s: str) -> int | Literal["auto", "logical"]:
if s in ("auto", "logical"):
return s
return s # type: ignore[return-value]
elif s is not None:
return int(s)
@pytest.hookimpl
def pytest_addoption(parser):
def pytest_addoption(parser: pytest.Parser) -> None:
# 'Help' formatting (same rules as pytest's):
# Start with capitalized letters.
# If a single phrase, do not end with period. If more than one phrase, all phrases end with periods.
@ -206,7 +209,7 @@ def pytest_addoption(parser):
@pytest.hookimpl
def pytest_addhooks(pluginmanager):
def pytest_addhooks(pluginmanager: pytest.PytestPluginManager) -> None:
from xdist import newhooks
pluginmanager.add_hookspecs(newhooks)
@ -218,7 +221,7 @@ def pytest_addhooks(pluginmanager):
@pytest.hookimpl(trylast=True)
def pytest_configure(config):
def pytest_configure(config: pytest.Config) -> None:
config_line = (
"xdist_group: specify group for tests should run in same session."
"in relation to one another. Provided by pytest-xdist."
@ -256,16 +259,13 @@ def pytest_configure(config):
config.issue_config_time_warning(warning, 2)
def _is_distribution_mode(config):
"""Return `True` if distribution mode is on, `False` otherwise.
:param config: the `pytest` `config` object
"""
return config.getoption("dist") != "no" and config.getoption("tx")
def _is_distribution_mode(config: pytest.Config) -> bool:
"""Whether distribution mode is on."""
return config.getoption("dist") != "no" and bool(config.getoption("tx"))
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: pytest.Config) -> None:
if config.option.distload:
config.option.dist = "load"
@ -302,7 +302,9 @@ def pytest_cmdline_main(config):
# -------------------------------------------------------------------------
def is_xdist_worker(request_or_session) -> bool:
def is_xdist_worker(
request_or_session: pytest.FixtureRequest | pytest.Session,
) -> bool:
"""Return `True` if this is an xdist worker, `False` otherwise.
:param request_or_session: the `pytest` `request` or `session` object
@ -310,7 +312,9 @@ def is_xdist_worker(request_or_session) -> bool:
return hasattr(request_or_session.config, "workerinput")
def is_xdist_controller(request_or_session) -> bool:
def is_xdist_controller(
request_or_session: pytest.FixtureRequest | pytest.Session,
) -> bool:
"""Return `True` if this is the xdist controller, `False` otherwise.
Note: this method also returns `False` when distribution has not been
@ -328,7 +332,9 @@ def is_xdist_controller(request_or_session) -> bool:
is_xdist_master = is_xdist_controller
def get_xdist_worker_id(request_or_session):
def get_xdist_worker_id(
request_or_session: pytest.FixtureRequest | pytest.Session,
) -> str:
"""Return the id of the current worker ('gw0', 'gw1', etc) or 'master'
if running on the controller node.
@ -338,14 +344,15 @@ def get_xdist_worker_id(request_or_session):
:param request_or_session: the `pytest` `request` or `session` object
"""
if hasattr(request_or_session.config, "workerinput"):
return request_or_session.config.workerinput["workerid"]
workerid: str = request_or_session.config.workerinput["workerid"]
return workerid
else:
# TODO: remove "master", ideally for a None
return "master"
@pytest.fixture(scope="session")
def worker_id(request):
def worker_id(request: pytest.FixtureRequest) -> str:
"""Return the id of the current worker ('gw0', 'gw1', etc) or 'master'
if running on the master node.
"""
@ -354,9 +361,10 @@ def worker_id(request):
@pytest.fixture(scope="session")
def testrun_uid(request):
def testrun_uid(request: pytest.FixtureRequest) -> str:
"""Return the unique id of the current test."""
if hasattr(request.config, "workerinput"):
return request.config.workerinput["testrunuid"]
testrunid: str = request.config.workerinput["testrunuid"]
return testrunid
else:
return uuid.uuid4().hex

View File

@ -6,16 +6,22 @@ on the rest of the xdist code. This means that the xdist-plugin
needs not to be installed in remote environments.
"""
from __future__ import annotations
import contextlib
import enum
import os
import sys
import time
from typing import Any
from typing import Generator
from typing import Literal
from typing import Sequence
from typing import TypedDict
import warnings
from _pytest.config import _prepareconfig
from execnet.gateway_base import DumpError
from execnet.gateway_base import dumps
import execnet
import pytest
@ -23,7 +29,7 @@ try:
from setproctitle import setproctitle
except ImportError:
def setproctitle(title):
def setproctitle(title: str) -> None:
pass
@ -35,7 +41,7 @@ class Producer:
to have the other way around.
"""
def __init__(self, name: str, *, enabled: bool = True):
def __init__(self, name: str, *, enabled: bool = True) -> None:
self.name = name
self.enabled = enabled
@ -46,11 +52,11 @@ class Producer:
if self.enabled:
print(f"[{self.name}]", *a, **k, file=sys.stderr)
def __getattr__(self, name: str) -> "Producer":
def __getattr__(self, name: str) -> Producer:
return type(self)(name, enabled=self.enabled)
def worker_title(title):
def worker_title(title: str) -> None:
try:
setproctitle(title)
except Exception:
@ -64,59 +70,63 @@ class Marker(enum.Enum):
class WorkerInteractor:
def __init__(self, config, channel):
def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
self.config = config
self.workerid = config.workerinput.get("workerid", "?")
self.testrunuid = config.workerinput["testrunuid"]
workerinput: dict[str, Any] = config.workerinput # type: ignore[attr-defined]
self.workerid = workerinput.get("workerid", "?")
self.testrunuid = workerinput["testrunuid"]
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
self.channel = channel
self.torun = self._make_queue()
self.nextitem_index = None
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
config.pluginmanager.register(self)
def _make_queue(self):
def _make_queue(self) -> Any:
return self.channel.gateway.execmodel.queue.Queue()
def _get_next_item_index(self):
def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
"""Gets the next item from test queue. Handles the case when the queue
is replaced concurrently in another thread.
"""
result = self.torun.get()
while result is Marker.QUEUE_REPLACED:
result = self.torun.get()
return result
return result # type: ignore[no-any-return]
def sendevent(self, name, **kwargs):
def sendevent(self, name: str, **kwargs: object) -> None:
self.log("sending", name, kwargs)
self.channel.send((name, kwargs))
@pytest.hookimpl
def pytest_internalerror(self, excrepr):
def pytest_internalerror(self, excrepr: object) -> None:
formatted_error = str(excrepr)
for line in formatted_error.split("\n"):
self.log("IERROR>", line)
interactor.sendevent("internal_error", formatted_error=formatted_error)
@pytest.hookimpl
def pytest_sessionstart(self, session):
def pytest_sessionstart(self, session: pytest.Session) -> None:
self.session = session
workerinfo = getinfodict()
self.sendevent("workerready", workerinfo=workerinfo)
@pytest.hookimpl(hookwrapper=True)
def pytest_sessionfinish(self, exitstatus):
def pytest_sessionfinish(self, exitstatus: int) -> Generator[None, object, None]:
workeroutput: dict[str, Any] = self.config.workeroutput # type: ignore[attr-defined]
# in pytest 5.0+, exitstatus is an IntEnum object
self.config.workeroutput["exitstatus"] = int(exitstatus)
self.config.workeroutput["shouldfail"] = self.session.shouldfail
self.config.workeroutput["shouldstop"] = self.session.shouldstop
workeroutput["exitstatus"] = int(exitstatus)
workeroutput["shouldfail"] = self.session.shouldfail
workeroutput["shouldstop"] = self.session.shouldstop
yield
self.sendevent("workerfinished", workeroutput=self.config.workeroutput)
self.sendevent("workerfinished", workeroutput=workeroutput)
@pytest.hookimpl
def pytest_collection(self, session):
def pytest_collection(self) -> None:
self.sendevent("collectionstart")
def handle_command(self, command):
def handle_command(
self, command: tuple[str, dict[str, Any]] | Literal[Marker.SHUTDOWN]
) -> None:
if command is Marker.SHUTDOWN:
self.torun.put(Marker.SHUTDOWN)
return
@ -135,18 +145,19 @@ class WorkerInteractor:
elif name == "steal":
self.steal(kwargs["indices"])
def steal(self, indices):
indices = set(indices)
def steal(self, indices: Sequence[int]) -> None:
indices_set = set(indices)
stolen = []
old_queue, self.torun = self.torun, self._make_queue()
def old_queue_get_nowait_noraise():
def old_queue_get_nowait_noraise() -> int | None:
with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty):
return old_queue.get_nowait()
return old_queue.get_nowait() # type: ignore[no-any-return]
return None
for i in iter(old_queue_get_nowait_noraise, None):
if i in indices:
if i in indices_set:
stolen.append(i)
else:
self.torun.put(i)
@ -155,7 +166,7 @@ class WorkerInteractor:
old_queue.put(Marker.QUEUE_REPLACED)
@pytest.hookimpl
def pytest_runtestloop(self, session):
def pytest_runtestloop(self, session: pytest.Session) -> bool:
self.log("entering main loop")
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
self.nextitem_index = self._get_next_item_index()
@ -165,7 +176,8 @@ class WorkerInteractor:
break
return True
def run_one_test(self):
def run_one_test(self) -> None:
assert isinstance(self.nextitem_index, int)
self.item_index = self.nextitem_index
self.nextitem_index = self._get_next_item_index()
@ -174,6 +186,7 @@ class WorkerInteractor:
if self.nextitem_index is Marker.SHUTDOWN:
nextitem = None
else:
assert self.nextitem_index is not None
nextitem = items[self.nextitem_index]
worker_title("[pytest-xdist running] %s" % item.nodeid)
@ -188,7 +201,11 @@ class WorkerInteractor:
"runtest_protocol_complete", item_index=self.item_index, duration=duration
)
def pytest_collection_modifyitems(self, session, config, items):
def pytest_collection_modifyitems(
self,
config: pytest.Config,
items: list[pytest.Item],
) -> None:
# add the group name to nodeid as suffix if --dist=loadgroup
if config.getvalue("loadgroup"):
for item in items:
@ -203,7 +220,7 @@ class WorkerInteractor:
item._nodeid = f"{item.nodeid}@{gname}"
@pytest.hookimpl
def pytest_collection_finish(self, session):
def pytest_collection_finish(self, session: pytest.Session) -> None:
self.sendevent(
"collectionfinish",
topdir=str(self.config.rootpath),
@ -211,15 +228,23 @@ class WorkerInteractor:
)
@pytest.hookimpl
def pytest_runtest_logstart(self, nodeid, location):
def pytest_runtest_logstart(
self,
nodeid: str,
location: tuple[str, int | None, str],
) -> None:
self.sendevent("logstart", nodeid=nodeid, location=location)
@pytest.hookimpl
def pytest_runtest_logfinish(self, nodeid, location):
def pytest_runtest_logfinish(
self,
nodeid: str,
location: tuple[str, int | None, str],
) -> None:
self.sendevent("logfinish", nodeid=nodeid, location=location)
@pytest.hookimpl
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
data = self.config.hook.pytest_report_to_serializable(
config=self.config, report=report
)
@ -230,7 +255,7 @@ class WorkerInteractor:
self.sendevent("testreport", data=data)
@pytest.hookimpl
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: pytest.CollectReport) -> None:
# send only reports that have not passed to controller as optimization (#330)
if not report.passed:
data = self.config.hook.pytest_report_to_serializable(
@ -239,7 +264,13 @@ class WorkerInteractor:
self.sendevent("collectreport", data=data)
@pytest.hookimpl
def pytest_warning_recorded(self, warning_message, when, nodeid, location):
def pytest_warning_recorded(
self,
warning_message: warnings.WarningMessage,
when: str,
nodeid: str,
location: tuple[str, int, str] | None,
) -> None:
self.sendevent(
"warning_recorded",
warning_message_data=serialize_warning_message(warning_message),
@ -249,7 +280,9 @@ class WorkerInteractor:
)
def serialize_warning_message(warning_message):
def serialize_warning_message(
warning_message: warnings.WarningMessage,
) -> dict[str, Any]:
if isinstance(warning_message.message, Warning):
message_module = type(warning_message.message).__module__
message_class_name = type(warning_message.message).__name__
@ -257,8 +290,8 @@ def serialize_warning_message(warning_message):
# check now if we can serialize the warning arguments (#349)
# if not, we will just use the exception message on the controller node
try:
dumps(warning_message.message.args)
except DumpError:
execnet.dumps(warning_message.message.args)
except execnet.DumpError:
message_args = None
else:
message_args = warning_message.message.args
@ -283,27 +316,38 @@ def serialize_warning_message(warning_message):
"category_class_name": category_class_name,
}
# access private _WARNING_DETAILS because the attributes vary between Python versions
for attr_name in warning_message._WARNING_DETAILS:
for attr_name in warning_message._WARNING_DETAILS: # type: ignore[attr-defined]
if attr_name in ("message", "category"):
continue
attr = getattr(warning_message, attr_name)
# Check if we can serialize the warning detail, marking `None` otherwise
# Note that we need to define the attr (even as `None`) to allow deserializing
try:
dumps(attr)
except DumpError:
execnet.dumps(attr)
except execnet.DumpError:
result[attr_name] = repr(attr)
else:
result[attr_name] = attr
return result
def getinfodict():
class WorkerInfo(TypedDict):
version: str
version_info: tuple[int, int, int, str, int]
sysplatform: str
platform: str
executable: str
cwd: str
id: str
spec: execnet.XSpec
def getinfodict() -> WorkerInfo:
import platform
return dict(
version=sys.version,
version_info=tuple(sys.version_info),
version_info=tuple(sys.version_info), # type: ignore[typeddict-item]
sysplatform=sys.platform,
platform=platform.platform(),
executable=sys.executable,
@ -311,7 +355,7 @@ def getinfodict():
)
def setup_config(config, basetemp):
def setup_config(config: pytest.Config, basetemp: str | None) -> None:
config.option.loadgroup = config.getvalue("dist") == "loadgroup"
config.option.looponfail = False
config.option.usepdb = False
@ -323,7 +367,7 @@ def setup_config(config, basetemp):
if __name__ == "__channelexec__":
channel = channel # type: ignore[name-defined] # noqa: F821, PLW0127
channel: execnet.Channel = channel # type: ignore[name-defined] # noqa: F821, PLW0127
workerinput, args, option_dict, change_sys_path = channel.receive() # type: ignore[name-defined]
if change_sys_path is None:

View File

@ -1,7 +1,15 @@
from __future__ import annotations
from difflib import unified_diff
from typing import Sequence
def report_collection_diff(from_collection, to_collection, from_id, to_id):
def report_collection_diff(
from_collection: Sequence[str],
to_collection: Sequence[str],
from_id: str,
to_id: str,
) -> str | None:
"""Report the collected test difference between two nodes.
:returns: detailed message describing the difference between the given

View File

@ -1,6 +1,13 @@
from __future__ import annotations
from typing import Sequence
import pytest
from xdist.remote import Producer
from xdist.report import report_collection_diff
from xdist.workermanage import parse_spec_config
from xdist.workermanage import WorkerController
class EachScheduling:
@ -17,13 +24,13 @@ class EachScheduling:
assigned the remaining items from the removed node.
"""
def __init__(self, config, log=None):
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
self.config = config
self.numnodes = len(parse_spec_config(config))
self.node2collection = {}
self.node2pending = {}
self._started = []
self._removed2pending = {}
self.node2collection: dict[WorkerController, list[str]] = {}
self.node2pending: dict[WorkerController, list[int]] = {}
self._started: list[WorkerController] = []
self._removed2pending: dict[WorkerController, list[int]] = {}
if log is None:
self.log = Producer("eachsched")
else:
@ -31,12 +38,12 @@ class EachScheduling:
self.collection_is_completed = False
@property
def nodes(self):
def nodes(self) -> list[WorkerController]:
"""A list of all nodes in the scheduler."""
return list(self.node2pending.keys())
@property
def tests_finished(self):
def tests_finished(self) -> bool:
if not self.collection_is_completed:
return False
if self._removed2pending:
@ -47,7 +54,7 @@ class EachScheduling:
return True
@property
def has_pending(self):
def has_pending(self) -> bool:
"""Return True if there are pending test items.
This indicates that collection has finished and nodes are
@ -59,11 +66,13 @@ class EachScheduling:
return True
return False
def add_node(self, node):
def add_node(self, node: WorkerController) -> None:
assert node not in self.node2pending
self.node2pending[node] = []
def add_node_collection(self, node, collection):
def add_node_collection(
self, node: WorkerController, collection: Sequence[str]
) -> None:
"""Add the collected test items from a node.
Collection is complete once all nodes have submitted their
@ -97,26 +106,32 @@ class EachScheduling:
self.node2pending[node] = pending
break
def mark_test_complete(self, node, item_index, duration=0):
def mark_test_complete(
self, node: WorkerController, item_index: int, duration: float = 0
) -> None:
self.node2pending[node].remove(item_index)
def mark_test_pending(self, item):
def mark_test_pending(self, item: str) -> None:
raise NotImplementedError()
def remove_pending_tests_from_node(self, node, indices):
def remove_pending_tests_from_node(
self,
node: WorkerController,
indices: Sequence[int],
) -> None:
raise NotImplementedError()
def remove_node(self, node):
def remove_node(self, node: WorkerController) -> str | None:
# KeyError if we didn't get an add_node() yet
pending = self.node2pending.pop(node)
if not pending:
return
return None
crashitem = self.node2collection[node][pending.pop(0)]
if pending:
self._removed2pending[node] = pending
return crashitem
def schedule(self):
def schedule(self) -> None:
"""Schedule the test items on the nodes.
If the node's pending list is empty it is a new node which

View File

@ -1,10 +1,14 @@
from __future__ import annotations
from itertools import cycle
from typing import Sequence
import pytest
from xdist.remote import Producer
from xdist.report import report_collection_diff
from xdist.workermanage import parse_spec_config
from xdist.workermanage import WorkerController
class LoadScheduling:
@ -53,12 +57,12 @@ class LoadScheduling:
:config: Config object, used for handling hooks.
"""
def __init__(self, config, log=None):
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
self.numnodes = len(parse_spec_config(config))
self.node2collection = {}
self.node2pending = {}
self.pending = []
self.collection = None
self.node2collection: dict[WorkerController, list[str]] = {}
self.node2pending: dict[WorkerController, list[int]] = {}
self.pending: list[int] = []
self.collection: list[str] | None = None
if log is None:
self.log = Producer("loadsched")
else:
@ -67,12 +71,12 @@ class LoadScheduling:
self.maxschedchunk = self.config.getoption("maxschedchunk")
@property
def nodes(self):
def nodes(self) -> list[WorkerController]:
"""A list of all nodes in the scheduler."""
return list(self.node2pending.keys())
@property
def collection_is_completed(self):
def collection_is_completed(self) -> bool:
"""Boolean indication initial test collection is complete.
This is a boolean indicating all initial participating nodes
@ -82,7 +86,7 @@ class LoadScheduling:
return len(self.node2collection) >= self.numnodes
@property
def tests_finished(self):
def tests_finished(self) -> bool:
"""Return True if all tests have been executed by the nodes."""
if not self.collection_is_completed:
return False
@ -94,7 +98,7 @@ class LoadScheduling:
return True
@property
def has_pending(self):
def has_pending(self) -> bool:
"""Return True if there are pending test items.
This indicates that collection has finished and nodes are
@ -108,7 +112,7 @@ class LoadScheduling:
return True
return False
def add_node(self, node):
def add_node(self, node: WorkerController) -> None:
"""Add a new node to the scheduler.
From now on the node will be allocated chunks of tests to
@ -120,7 +124,9 @@ class LoadScheduling:
assert node not in self.node2pending
self.node2pending[node] = []
def add_node_collection(self, node, collection):
def add_node_collection(
self, node: WorkerController, collection: Sequence[str]
) -> None:
"""Add the collected test items from a node.
The collection is stored in the ``.node2collection`` map.
@ -141,7 +147,9 @@ class LoadScheduling:
return
self.node2collection[node] = list(collection)
def mark_test_complete(self, node, item_index, duration=0):
def mark_test_complete(
self, node: WorkerController, item_index: int, duration: float = 0
) -> None:
"""Mark test item as completed by node.
The duration it took to execute the item is used as a hint to
@ -152,7 +160,8 @@ class LoadScheduling:
self.node2pending[node].remove(item_index)
self.check_schedule(node, duration=duration)
def mark_test_pending(self, item):
def mark_test_pending(self, item: str) -> None:
assert self.collection is not None
self.pending.insert(
0,
self.collection.index(item),
@ -160,10 +169,14 @@ class LoadScheduling:
for node in self.node2pending:
self.check_schedule(node)
def remove_pending_tests_from_node(self, node, indices):
def remove_pending_tests_from_node(
self,
node: WorkerController,
indices: Sequence[int],
) -> None:
raise NotImplementedError()
def check_schedule(self, node, duration=0):
def check_schedule(self, node: WorkerController, duration: float = 0) -> None:
"""Maybe schedule new items on the node.
If there are any globally pending nodes left then this will
@ -197,7 +210,7 @@ class LoadScheduling:
self.log("num items waiting for node:", len(self.pending))
def remove_node(self, node):
def remove_node(self, node: WorkerController) -> str | None:
"""Remove a node from the scheduler.
This should be called either when the node crashed or at
@ -212,16 +225,17 @@ class LoadScheduling:
"""
pending = self.node2pending.pop(node)
if not pending:
return
return None
# The node crashed, reassing pending items
assert self.collection is not None
crashitem = self.collection[pending.pop(0)]
self.pending.extend(pending)
for node in self.node2pending:
self.check_schedule(node)
return crashitem
def schedule(self):
def schedule(self) -> None:
"""Initiate distribution of the test collection.
Initiate scheduling of the items across the nodes. If this
@ -285,14 +299,14 @@ class LoadScheduling:
for node in self.nodes:
node.shutdown()
def _send_tests(self, node, num):
def _send_tests(self, node: WorkerController, num: int) -> None:
tests_per_node = self.pending[:num]
if tests_per_node:
del self.pending[:num]
self.node2pending[node].extend(tests_per_node)
node.send_runtest_some(tests_per_node)
def _check_nodes_have_same_collection(self):
def _check_nodes_have_same_collection(self) -> bool:
"""Return True if all nodes have collected the same items.
If collections differ, this method returns False while logging

View File

@ -1,3 +1,7 @@
from __future__ import annotations
import pytest
from xdist.remote import Producer
from .loadscope import LoadScopeScheduling
@ -21,14 +25,14 @@ class LoadFileScheduling(LoadScopeScheduling):
This class behaves very much like LoadScopeScheduling, but with a file-level scope.
"""
def __init__(self, config, log=None):
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
super().__init__(config, log)
if log is None:
self.log = Producer("loadfilesched")
else:
self.log = log.loadfilesched
def _split_scope(self, nodeid):
def _split_scope(self, nodeid: str) -> str:
"""Determine the scope (grouping) of a nodeid.
There are usually 3 cases for a nodeid::

View File

@ -1,3 +1,7 @@
from __future__ import annotations
import pytest
from xdist.remote import Producer
from .loadscope import LoadScopeScheduling
@ -10,14 +14,14 @@ class LoadGroupScheduling(LoadScopeScheduling):
instead of the module or class to which they belong to.
"""
def __init__(self, config, log=None):
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
super().__init__(config, log)
if log is None:
self.log = Producer("loadgroupsched")
else:
self.log = log.loadgroupsched
def _split_scope(self, nodeid):
def _split_scope(self, nodeid: str) -> str:
"""Determine the scope (grouping) of a nodeid.
There are usually 3 cases for a nodeid::

View File

@ -1,10 +1,15 @@
from __future__ import annotations
from collections import OrderedDict
from typing import NoReturn
from typing import Sequence
import pytest
from xdist.remote import Producer
from xdist.report import report_collection_diff
from xdist.workermanage import parse_spec_config
from xdist.workermanage import WorkerController
class LoadScopeScheduling:
@ -85,13 +90,13 @@ class LoadScopeScheduling:
:config: Config object, used for handling hooks.
"""
def __init__(self, config, log=None):
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
self.numnodes = len(parse_spec_config(config))
self.collection = None
self.collection: list[str] | None = None
self.workqueue = OrderedDict()
self.assigned_work = {}
self.registered_collections = {}
self.workqueue: OrderedDict[str, dict[str, bool]] = OrderedDict()
self.assigned_work: dict[WorkerController, dict[str, dict[str, bool]]] = {}
self.registered_collections: dict[WorkerController, list[str]] = {}
if log is None:
self.log = Producer("loadscopesched")
@ -101,12 +106,12 @@ class LoadScopeScheduling:
self.config = config
@property
def nodes(self):
def nodes(self) -> list[WorkerController]:
"""A list of all active nodes in the scheduler."""
return list(self.assigned_work.keys())
@property
def collection_is_completed(self):
def collection_is_completed(self) -> bool:
"""Boolean indication initial test collection is complete.
This is a boolean indicating all initial participating nodes have
@ -116,7 +121,7 @@ class LoadScopeScheduling:
return len(self.registered_collections) >= self.numnodes
@property
def tests_finished(self):
def tests_finished(self) -> bool:
"""Return True if all tests have been executed by the nodes."""
if not self.collection_is_completed:
return False
@ -131,7 +136,7 @@ class LoadScopeScheduling:
return True
@property
def has_pending(self):
def has_pending(self) -> bool:
"""Return True if there are pending test items.
This indicates that collection has finished and nodes are still
@ -147,7 +152,7 @@ class LoadScopeScheduling:
return False
def add_node(self, node):
def add_node(self, node: WorkerController) -> None:
"""Add a new node to the scheduler.
From now on the node will be assigned work units to be executed.
@ -158,7 +163,7 @@ class LoadScopeScheduling:
assert node not in self.assigned_work
self.assigned_work[node] = {}
def remove_node(self, node):
def remove_node(self, node: WorkerController) -> str | None:
"""Remove a node from the scheduler.
This should be called either when the node crashed or at shutdown time.
@ -199,7 +204,9 @@ class LoadScopeScheduling:
return crashitem
def add_node_collection(self, node, collection):
def add_node_collection(
self, node: WorkerController, collection: Sequence[str]
) -> None:
"""Add the collected test items from a node.
The collection is stored in the ``.registered_collections`` dictionary.
@ -228,7 +235,9 @@ class LoadScopeScheduling:
self.registered_collections[node] = list(collection)
def mark_test_complete(self, node, item_index, duration=0):
def mark_test_complete(
self, node: WorkerController, item_index: int, duration: float = 0
) -> None:
"""Mark test item as completed by node.
Called by the hook:
@ -241,13 +250,17 @@ class LoadScopeScheduling:
self.assigned_work[node][scope][nodeid] = True
self._reschedule(node)
def mark_test_pending(self, item):
def mark_test_pending(self, item: str) -> NoReturn:
raise NotImplementedError()
def remove_pending_tests_from_node(self, node, indices):
def remove_pending_tests_from_node(
self,
node: WorkerController,
indices: Sequence[int],
) -> None:
raise NotImplementedError()
def _assign_work_unit(self, node):
def _assign_work_unit(self, node: WorkerController) -> None:
"""Assign a work unit to a node."""
assert self.workqueue
@ -268,7 +281,7 @@ class LoadScopeScheduling:
node.send_runtest_some(nodeids_indexes)
def _split_scope(self, nodeid):
def _split_scope(self, nodeid: str) -> str:
"""Determine the scope (grouping) of a nodeid.
There are usually 3 cases for a nodeid::
@ -292,12 +305,12 @@ class LoadScopeScheduling:
"""
return nodeid.rsplit("::", 1)[0]
def _pending_of(self, workload):
def _pending_of(self, workload: dict[str, dict[str, bool]]) -> int:
"""Return the number of pending tests in a workload."""
pending = sum(list(scope.values()).count(False) for scope in workload.values())
return pending
def _reschedule(self, node):
def _reschedule(self, node: WorkerController) -> None:
"""Maybe schedule new items on the node.
If there are any globally pending work units left then this will check
@ -322,7 +335,7 @@ class LoadScopeScheduling:
# Pop one unit of work and assign it
self._assign_work_unit(node)
def schedule(self):
def schedule(self) -> None:
"""Initiate distribution of the test collection.
Initiate scheduling of the items across the nodes. If this gets called
@ -352,7 +365,7 @@ class LoadScopeScheduling:
return
# Determine chunks of work (scopes)
unsorted_workqueue = {}
unsorted_workqueue: dict[str, dict[str, bool]] = {}
for nodeid in self.collection:
scope = self._split_scope(nodeid)
work_unit = unsorted_workqueue.setdefault(scope, {})
@ -389,7 +402,7 @@ class LoadScopeScheduling:
for node in self.nodes:
node.shutdown()
def _check_nodes_have_same_collection(self):
def _check_nodes_have_same_collection(self) -> bool:
"""Return True if all nodes have collected the same items.
If collections differ, this method returns False while logging

View File

@ -1,17 +1,18 @@
from __future__ import annotations
from typing import Any
from typing import NamedTuple
from typing import Sequence
import pytest
from xdist.remote import Producer
from xdist.report import report_collection_diff
from xdist.workermanage import parse_spec_config
from xdist.workermanage import WorkerController
class NodePending(NamedTuple):
node: Any
node: WorkerController
pending: list[int]
@ -63,26 +64,26 @@ class WorkStealingScheduling:
simultaneous requests.
"""
def __init__(self, config, log=None):
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
self.numnodes = len(parse_spec_config(config))
self.node2collection = {}
self.node2pending = {}
self.pending = []
self.collection = None
self.node2collection: dict[WorkerController, list[str]] = {}
self.node2pending: dict[WorkerController, list[int]] = {}
self.pending: list[int] = []
self.collection: list[str] | None = None
if log is None:
self.log = Producer("workstealsched")
else:
self.log = log.workstealsched
self.config = config
self.steal_requested_from_node = None
self.steal_requested_from_node: WorkerController | None = None
@property
def nodes(self):
def nodes(self) -> list[WorkerController]:
"""A list of all nodes in the scheduler."""
return list(self.node2pending.keys())
@property
def collection_is_completed(self):
def collection_is_completed(self) -> bool:
"""Boolean indication initial test collection is complete.
This is a boolean indicating all initial participating nodes
@ -92,7 +93,7 @@ class WorkStealingScheduling:
return len(self.node2collection) >= self.numnodes
@property
def tests_finished(self):
def tests_finished(self) -> bool:
"""Return True if all tests have been executed by the nodes."""
if not self.collection_is_completed:
return False
@ -106,7 +107,7 @@ class WorkStealingScheduling:
return True
@property
def has_pending(self):
def has_pending(self) -> bool:
"""Return True if there are pending test items.
This indicates that collection has finished and nodes are
@ -120,7 +121,7 @@ class WorkStealingScheduling:
return True
return False
def add_node(self, node):
def add_node(self, node: WorkerController) -> None:
"""Add a new node to the scheduler.
From now on the node will be allocated chunks of tests to
@ -132,7 +133,9 @@ class WorkStealingScheduling:
assert node not in self.node2pending
self.node2pending[node] = []
def add_node_collection(self, node, collection):
def add_node_collection(
self, node: WorkerController, collection: Sequence[str]
) -> None:
"""Add the collected test items from a node.
The collection is stored in the ``.node2collection`` map.
@ -153,7 +156,9 @@ class WorkStealingScheduling:
return
self.node2collection[node] = list(collection)
def mark_test_complete(self, node, item_index, duration=None):
def mark_test_complete(
self, node: WorkerController, item_index: int, duration: float | None = None
) -> None:
"""Mark test item as completed by node.
This is called by the ``DSession.worker_testreport`` hook.
@ -161,14 +166,19 @@ class WorkStealingScheduling:
self.node2pending[node].remove(item_index)
self.check_schedule()
def mark_test_pending(self, item):
def mark_test_pending(self, item: str) -> None:
assert self.collection is not None
self.pending.insert(
0,
self.collection.index(item),
)
self.check_schedule()
def remove_pending_tests_from_node(self, node, indices):
def remove_pending_tests_from_node(
self,
node: WorkerController,
indices: Sequence[int],
) -> None:
"""Node returned some test indices back in response to 'steal' command.
This is called by ``DSession.worker_unscheduled``.
@ -183,7 +193,7 @@ class WorkStealingScheduling:
self.pending.extend(indices)
self.check_schedule()
def check_schedule(self):
def check_schedule(self) -> None:
"""Reschedule tests/perform load balancing."""
nodes_up = [
NodePending(node, pending)
@ -191,7 +201,7 @@ class WorkStealingScheduling:
if not node.shutting_down
]
def get_idle_nodes():
def get_idle_nodes() -> list[WorkerController]:
return [node for node, pending in nodes_up if len(pending) < MIN_PENDING]
idle_nodes = get_idle_nodes()
@ -235,10 +245,11 @@ class WorkStealingScheduling:
node.shutdown()
return
assert steal_from is not None
steal_from.node.send_steal(steal_from.pending[-num_steal:])
self.steal_requested_from_node = steal_from.node
def remove_node(self, node):
def remove_node(self, node: WorkerController) -> str | None:
"""Remove a node from the scheduler.
This should be called either when the node crashed or at
@ -249,12 +260,12 @@ class WorkStealingScheduling:
Return the item which was being executing while the node
crashed or None if the node has no more pending items.
"""
pending = self.node2pending.pop(node)
# If node was removed without completing its assigned tests - it crashed
if pending:
assert self.collection is not None
crashitem = self.collection[pending.pop(0)]
else:
crashitem = None
@ -268,7 +279,7 @@ class WorkStealingScheduling:
self.check_schedule()
return crashitem
def schedule(self):
def schedule(self) -> None:
"""Initiate distribution of the test collection.
Initiate scheduling of the items across the nodes. If this
@ -298,14 +309,14 @@ class WorkStealingScheduling:
self.check_schedule()
def _send_tests(self, node, num):
def _send_tests(self, node: WorkerController, num: int) -> None:
tests_per_node = self.pending[:num]
if tests_per_node:
del self.pending[:num]
self.node2pending[node].extend(tests_per_node)
node.send_runtest_some(tests_per_node)
def _check_nodes_have_same_collection(self):
def _check_nodes_have_same_collection(self) -> bool:
"""Return True if all nodes have collected the same items.
If collections differ, this method returns False while logging

View File

@ -7,9 +7,12 @@ from pathlib import Path
import re
import sys
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence
from typing import Union
import uuid
import warnings
import execnet
import pytest
@ -17,11 +20,13 @@ import pytest
from xdist.plugin import _sys_path
import xdist.remote
from xdist.remote import Producer
from xdist.remote import WorkerInfo
def parse_spec_config(config):
def parse_spec_config(config: pytest.Config) -> list[str]:
xspeclist = []
for xspec in config.getvalue("tx"):
tx: list[str] = config.getvalue("tx")
for xspec in tx:
i = xspec.find("*")
try:
num = int(xspec[:i])
@ -40,7 +45,12 @@ class NodeManager:
EXIT_TIMEOUT = 10
DEFAULT_IGNORES = [".*", "*.pyc", "*.pyo", "*~"]
def __init__(self, config, specs=None, defaultchdir="pyexecnetcache") -> None:
def __init__(
self,
config: pytest.Config,
specs: Sequence[execnet.XSpec | str] | None = None,
defaultchdir: str = "pyexecnetcache",
) -> None:
self.config = config
self.trace = self.config.trace.get("nodemanager")
self.testrunuid = self.config.getoption("testrunuid")
@ -49,7 +59,7 @@ class NodeManager:
self.group = execnet.Group()
if specs is None:
specs = self._getxspecs()
self.specs = []
self.specs: list[execnet.XSpec] = []
for spec in specs:
if not isinstance(spec, execnet.XSpec):
spec = execnet.XSpec(spec)
@ -61,31 +71,39 @@ class NodeManager:
self.rsyncoptions = self._getrsyncoptions()
self._rsynced_specs: set[tuple[Any, Any]] = set()
def rsync_roots(self, gateway):
def rsync_roots(self, gateway: execnet.Gateway) -> None:
"""Rsync the set of roots to the node's gateway cwd."""
if self.roots:
for root in self.roots:
self.rsync(gateway, root, **self.rsyncoptions)
def setup_nodes(self, putevent):
def setup_nodes(
self,
putevent: Callable[[tuple[str, dict[str, Any]]], None],
) -> list[WorkerController]:
self.config.hook.pytest_xdist_setupnodes(config=self.config, specs=self.specs)
self.trace("setting up nodes")
return [self.setup_node(spec, putevent) for spec in self.specs]
def setup_node(self, spec, putevent):
def setup_node(
self,
spec: execnet.XSpec,
putevent: Callable[[tuple[str, dict[str, Any]]], None],
) -> WorkerController:
gw = self.group.makegateway(spec)
self.config.hook.pytest_xdist_newgateway(gateway=gw)
self.rsync_roots(gw)
node = WorkerController(self, gw, self.config, putevent)
gw.node = node # keep the node alive
# Keep the node alive.
gw.node = node # type: ignore[attr-defined]
node.setup()
self.trace("started node %r" % node)
return node
def teardown_nodes(self):
def teardown_nodes(self) -> None:
self.group.terminate(self.EXIT_TIMEOUT)
def _getxspecs(self):
def _getxspecs(self) -> list[execnet.XSpec]:
return [execnet.XSpec(x) for x in parse_spec_config(self.config)]
def _getrsyncdirs(self) -> list[Path]:
@ -97,7 +115,7 @@ class NodeManager:
import _pytest
import pytest
def get_dir(p):
def get_dir(p: str) -> str:
"""Return the directory path if p is a package or the path to the .py file otherwise."""
stripped = p.rstrip("co")
if os.path.basename(stripped) == "__init__.py":
@ -115,14 +133,14 @@ class NodeManager:
candidates.extend(rsyncroots)
roots = []
for root in candidates:
root = Path(root).resolve()
if not root.exists():
root_path = Path(root).resolve()
if not root_path.exists():
raise pytest.UsageError(f"rsyncdir doesn't exist: {root!r}")
if root not in roots:
roots.append(root)
if root_path not in roots:
roots.append(root_path)
return roots
def _getrsyncoptions(self):
def _getrsyncoptions(self) -> dict[str, Any]:
"""Get options to be passed for rsync."""
ignores = list(self.DEFAULT_IGNORES)
ignores += [str(path) for path in self.config.option.rsyncignore]
@ -133,7 +151,16 @@ class NodeManager:
"verbose": getattr(self.config.option, "verbose", 0),
}
def rsync(self, gateway, source, notify=None, verbose=False, ignores=None):
def rsync(
self,
gateway: execnet.Gateway,
source: str | os.PathLike[str],
notify: (
Callable[[str, execnet.XSpec, str | os.PathLike[str]], Any] | None
) = None,
verbose: int = False,
ignores: Sequence[str] | None = None,
) -> None:
"""Perform rsync to remote hosts for node."""
# XXX This changes the calling behaviour of
# pytest_xdist_rsyncstart and pytest_xdist_rsyncfinish to
@ -153,7 +180,7 @@ class NodeManager:
if (spec, source) in self._rsynced_specs:
return
def finished():
def finished() -> None:
if notify:
notify("rsyncrootready", spec, source)
@ -189,11 +216,19 @@ class HostRSync(execnet.RSync):
else:
return True
def add_target_host(self, gateway, finished=None):
def add_target_host(
self,
gateway: execnet.Gateway,
finished: Callable[[], None] | None = None,
) -> None:
remotepath = os.path.basename(self._sourcedir)
super().add_target(gateway, remotepath, finishedcallback=finished, delete=True)
def _report_send_file(self, gateway, modified_rel_path):
def _report_send_file(
self,
gateway: execnet.Gateway, # type: ignore[override]
modified_rel_path: str,
) -> None:
if self._verbose > 0:
path = os.path.basename(self._sourcedir) + "/" + modified_rel_path
remotepath = gateway.spec.chdir
@ -234,12 +269,21 @@ class Marker(enum.Enum):
class WorkerController:
# Set when the worker is ready.
workerinfo: WorkerInfo
class RemoteHook:
@pytest.hookimpl(trylast=True)
def pytest_xdist_getremotemodule(self):
def pytest_xdist_getremotemodule(self) -> Any:
return xdist.remote
def __init__(self, nodemanager, gateway, config, putevent):
def __init__(
self,
nodemanager: NodeManager,
gateway: execnet.Gateway,
config: pytest.Config,
putevent: Callable[[tuple[str, dict[str, Any]]], None],
) -> None:
config.pluginmanager.register(self.RemoteHook())
self.nodemanager = nodemanager
self.putevent = putevent
@ -255,14 +299,14 @@ class WorkerController:
self._shutdown_sent = False
self.log = Producer(f"workerctl-{gateway.id}", enabled=config.option.debug)
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.gateway.id}>"
@property
def shutting_down(self):
def shutting_down(self) -> bool:
return self._down or self._shutdown_sent
def setup(self):
def setup(self) -> None:
self.log("setting up worker session")
spec = self.gateway.spec
args = [str(x) for x in self.config.invocation_params.args or ()]
@ -283,10 +327,11 @@ class WorkerController:
change_sys_path = _sys_path if self.gateway.spec.popen else None
self.channel.send((self.workerinput, args, option_dict, change_sys_path))
if self.putevent:
# putevent is only None in a test.
if self.putevent: # type: ignore[truthy-function]
self.channel.setcallback(self.process_from_remote, endmarker=Marker.END)
def ensure_teardown(self):
def ensure_teardown(self) -> None:
if hasattr(self, "channel"):
if not self.channel.isclosed():
self.log("closing", self.channel)
@ -297,16 +342,16 @@ class WorkerController:
self.gateway.exit()
# del self.gateway
def send_runtest_some(self, indices):
def send_runtest_some(self, indices: Sequence[int]) -> None:
self.sendcommand("runtests", indices=indices)
def send_runtest_all(self):
def send_runtest_all(self) -> None:
self.sendcommand("runtests_all")
def send_steal(self, indices):
def send_steal(self, indices: Sequence[int]) -> None:
self.sendcommand("steal", indices=indices)
def shutdown(self):
def shutdown(self) -> None:
if not self._down:
try:
self.sendcommand("shutdown")
@ -314,16 +359,18 @@ class WorkerController:
pass
self._shutdown_sent = True
def sendcommand(self, name, **kwargs):
def sendcommand(self, name: str, **kwargs: object) -> None:
"""Send a named parametrized command to the other side."""
self.log(f"sending command {name}(**{kwargs})")
self.channel.send((name, kwargs))
def notify_inproc(self, eventname, **kwargs):
def notify_inproc(self, eventname: str, **kwargs: object) -> None:
self.log(f"queuing {eventname}(**{kwargs})")
self.putevent((eventname, kwargs))
def process_from_remote(self, eventcall):
def process_from_remote(
self, eventcall: tuple[str, dict[str, Any]] | Literal[Marker.END]
) -> None:
"""This gets called for each object we receive from
the other side and if the channel closes.
@ -333,7 +380,7 @@ class WorkerController:
"""
try:
if eventcall is Marker.END:
err = self.channel._getremoteerror()
err: object | None = self.channel._getremoteerror() # type: ignore[no-untyped-call]
if not self._down:
if not err or isinstance(err, EOFError):
err = "Not properly terminated" # lost connection?
@ -399,9 +446,8 @@ class WorkerController:
self.notify_inproc("errordown", node=self, error=excinfo)
def unserialize_warning_message(data):
def unserialize_warning_message(data: dict[str, Any]) -> warnings.WarningMessage:
import importlib
import warnings
if data["message_module"]:
mod = importlib.import_module(data["message_module"])
@ -438,4 +484,4 @@ def unserialize_warning_message(data):
continue
kwargs[attr_name] = data[attr_name]
return warnings.WarningMessage(**kwargs) # type: ignore[arg-type]
return warnings.WarningMessage(**kwargs)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import os
import re
import shutil
from typing import cast
import pytest
@ -223,7 +224,7 @@ class TestDistribution:
assert result.ret == 1
def test_distribution_rsyncdirs_example(
self, pytester: pytest.Pytester, monkeypatch
self, pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch
) -> None:
# use a custom plugin that has a custom command-line option to ensure
# this is propagated to workers (see #491)
@ -415,7 +416,7 @@ class TestDistEach:
class TestTerminalReporting:
@pytest.mark.parametrize("verbosity", ["", "-q", "-v"])
def test_output_verbosity(self, pytester, verbosity: str) -> None:
def test_output_verbosity(self, pytester: pytest.Pytester, verbosity: str) -> None:
pytester.makepyfile(
"""
def test_ok():
@ -610,7 +611,7 @@ def test_fixture_teardown_failure(pytester: pytest.Pytester) -> None:
def test_config_initialization(
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, pytestconfig
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Ensure workers and controller are initialized consistently. Integration test for #445."""
pytester.makepyfile(
@ -635,7 +636,7 @@ def test_config_initialization(
@pytest.mark.parametrize("when", ["setup", "call", "teardown"])
def test_crashing_item(pytester, when) -> None:
def test_crashing_item(pytester: pytest.Pytester, when: str) -> None:
"""Ensure crashing item is correctly reported during all testing stages."""
code = dict(setup="", call="", teardown="")
code[when] = "os._exit(1)"
@ -766,7 +767,7 @@ def test_tmpdir_disabled(pytester: pytest.Pytester) -> None:
@pytest.mark.parametrize("plugin", ["xdist.looponfail"])
def test_sub_plugins_disabled(pytester, plugin) -> None:
def test_sub_plugins_disabled(pytester: pytest.Pytester, plugin: str) -> None:
"""Test that xdist doesn't break if we disable any of its sub-plugins (#32)."""
p1 = pytester.makepyfile(
"""
@ -781,7 +782,7 @@ def test_sub_plugins_disabled(pytester, plugin) -> None:
class TestWarnings:
@pytest.mark.parametrize("n", ["-n0", "-n1"])
def test_warnings(self, pytester, n) -> None:
def test_warnings(self, pytester: pytest.Pytester, n: str) -> None:
pytester.makepyfile(
"""
import warnings, py, pytest
@ -827,7 +828,7 @@ class TestWarnings:
result.stdout.no_fnmatch_line("*this hook should not be called in this version")
@pytest.mark.parametrize("n", ["-n0", "-n1"])
def test_custom_subclass(self, pytester, n) -> None:
def test_custom_subclass(self, pytester: pytest.Pytester, n: str) -> None:
"""Check that warning subclasses that don't honor the args attribute don't break
pytest-xdist (#344).
"""
@ -851,7 +852,7 @@ class TestWarnings:
result.stdout.fnmatch_lines(["*MyWarning*", "*1 passed, 1 warning*"])
@pytest.mark.parametrize("n", ["-n0", "-n1"])
def test_unserializable_arguments(self, pytester, n) -> None:
def test_unserializable_arguments(self, pytester: pytest.Pytester, n: str) -> None:
"""Check that warnings with unserializable arguments are handled correctly (#349)."""
pytester.makepyfile(
"""
@ -869,7 +870,9 @@ class TestWarnings:
result.stdout.fnmatch_lines(["*UserWarning*foo.txt*", "*1 passed, 1 warning*"])
@pytest.mark.parametrize("n", ["-n0", "-n1"])
def test_unserializable_warning_details(self, pytester, n) -> None:
def test_unserializable_warning_details(
self, pytester: pytest.Pytester, n: str
) -> None:
"""Check that warnings with unserializable _WARNING_DETAILS are
handled correctly (#379).
"""
@ -1049,7 +1052,7 @@ class TestNodeFailure:
@pytest.mark.parametrize("n", [0, 2])
def test_worker_id_fixture(pytester, n) -> None:
def test_worker_id_fixture(pytester: pytest.Pytester, n: int) -> None:
import glob
f = pytester.makepyfile(
@ -1065,8 +1068,8 @@ def test_worker_id_fixture(pytester, n) -> None:
result.stdout.fnmatch_lines("* 2 passed in *")
worker_ids = set()
for fname in glob.glob(str(pytester.path / "*.txt")):
with open(fname) as f:
worker_ids.add(f.read().strip())
with open(fname) as fp:
worker_ids.add(fp.read().strip())
if n == 0:
assert worker_ids == {"master"}
else:
@ -1074,7 +1077,7 @@ def test_worker_id_fixture(pytester, n) -> None:
@pytest.mark.parametrize("n", [0, 2])
def test_testrun_uid_fixture(pytester, n) -> None:
def test_testrun_uid_fixture(pytester: pytest.Pytester, n: int) -> None:
import glob
f = pytester.makepyfile(
@ -1090,14 +1093,14 @@ def test_testrun_uid_fixture(pytester, n) -> None:
result.stdout.fnmatch_lines("* 2 passed in *")
testrun_uids = set()
for fname in glob.glob(str(pytester.path / "*.txt")):
with open(fname) as f:
testrun_uids.add(f.read().strip())
with open(fname) as fp:
testrun_uids.add(fp.read().strip())
assert len(testrun_uids) == 1
assert len(testrun_uids.pop()) == 32
@pytest.mark.parametrize("tb", ["auto", "long", "short", "no", "line", "native"])
def test_error_report_styles(pytester, tb) -> None:
def test_error_report_styles(pytester: pytest.Pytester, tb: str) -> None:
pytester.makepyfile(
"""
import pytest
@ -1111,7 +1114,7 @@ def test_error_report_styles(pytester, tb) -> None:
result.assert_outcomes(failed=1)
def test_color_yes_collection_on_non_atty(pytester) -> None:
def test_color_yes_collection_on_non_atty(pytester: pytest.Pytester) -> None:
"""Skip collect progress report when working on non-terminals.
Similar to pytest-dev/pytest#1397
@ -1133,7 +1136,7 @@ def test_color_yes_collection_on_non_atty(pytester) -> None:
assert "collecting:" not in result.stdout.str()
def test_without_terminal_plugin(pytester, request) -> None:
def test_without_terminal_plugin(pytester: pytest.Pytester) -> None:
"""No output when terminal plugin is disabled."""
pytester.makepyfile(
"""
@ -1368,7 +1371,7 @@ class TestFileScope:
class TestGroupScope:
def test_by_module(self, pytester: pytest.Pytester):
def test_by_module(self, pytester: pytest.Pytester) -> None:
test_file = """
import pytest
class TestA:
@ -1399,7 +1402,7 @@ class TestGroupScope:
== test_b_workers_and_test_count.items()
)
def test_by_class(self, pytester: pytest.Pytester):
def test_by_class(self, pytester: pytest.Pytester) -> None:
pytester.makepyfile(
test_a="""
import pytest
@ -1436,7 +1439,7 @@ class TestGroupScope:
== test_b_workers_and_test_count.items()
)
def test_module_single_start(self, pytester: pytest.Pytester):
def test_module_single_start(self, pytester: pytest.Pytester) -> None:
test_file1 = """
import pytest
@pytest.mark.xdist_group(name="xdist_group")
@ -1459,7 +1462,7 @@ class TestGroupScope:
assert a.keys() == b.keys() and b.keys() == c.keys()
def test_with_two_group_names(self, pytester: pytest.Pytester):
def test_with_two_group_names(self, pytester: pytest.Pytester) -> None:
test_file = """
import pytest
@pytest.mark.xdist_group(name="group1")
@ -1512,7 +1515,7 @@ class TestLocking:
@pytest.mark.parametrize(
"scope", ["each", "load", "loadscope", "loadfile", "worksteal", "no"]
)
def test_single_file(self, pytester, scope) -> None:
def test_single_file(self, pytester: pytest.Pytester, scope: str) -> None:
pytester.makepyfile(test_a=self.test_file1)
result = pytester.runpytest("-n2", "--dist=%s" % scope, "-v")
result.assert_outcomes(passed=(12 if scope != "each" else 12 * 2))
@ -1520,7 +1523,7 @@ class TestLocking:
@pytest.mark.parametrize(
"scope", ["each", "load", "loadscope", "loadfile", "worksteal", "no"]
)
def test_multi_file(self, pytester, scope) -> None:
def test_multi_file(self, pytester: pytest.Pytester, scope: str) -> None:
pytester.makepyfile(
test_a=self.test_file1,
test_b=self.test_file1,
@ -1564,32 +1567,32 @@ def get_workers_and_test_count_by_prefix(
class TestAPI:
@pytest.fixture
def fake_request(self):
def fake_request(self) -> pytest.FixtureRequest:
class FakeOption:
def __init__(self):
def __init__(self) -> None:
self.dist = "load"
class FakeConfig:
def __init__(self):
def __init__(self) -> None:
self.workerinput = {"workerid": "gw5"}
self.option = FakeOption()
class FakeRequest:
def __init__(self):
def __init__(self) -> None:
self.config = FakeConfig()
return FakeRequest()
return cast(pytest.FixtureRequest, FakeRequest())
def test_is_xdist_worker(self, fake_request) -> None:
def test_is_xdist_worker(self, fake_request: pytest.FixtureRequest) -> None:
assert xdist.is_xdist_worker(fake_request)
del fake_request.config.workerinput
del fake_request.config.workerinput # type: ignore[attr-defined]
assert not xdist.is_xdist_worker(fake_request)
def test_is_xdist_controller(self, fake_request) -> None:
def test_is_xdist_controller(self, fake_request: pytest.FixtureRequest) -> None:
assert not xdist.is_xdist_master(fake_request)
assert not xdist.is_xdist_controller(fake_request)
del fake_request.config.workerinput
del fake_request.config.workerinput # type: ignore[attr-defined]
assert xdist.is_xdist_master(fake_request)
assert xdist.is_xdist_controller(fake_request)
@ -1597,13 +1600,13 @@ class TestAPI:
assert not xdist.is_xdist_master(fake_request)
assert not xdist.is_xdist_controller(fake_request)
def test_get_xdist_worker_id(self, fake_request) -> None:
def test_get_xdist_worker_id(self, fake_request: pytest.FixtureRequest) -> None:
assert xdist.get_xdist_worker_id(fake_request) == "gw5"
del fake_request.config.workerinput
del fake_request.config.workerinput # type: ignore[attr-defined]
assert xdist.get_xdist_worker_id(fake_request) == "master"
def test_collection_crash(pytester: pytest.Pytester):
def test_collection_crash(pytester: pytest.Pytester) -> None:
p1 = pytester.makepyfile(
"""
assert 0
@ -1622,7 +1625,7 @@ def test_collection_crash(pytester: pytest.Pytester):
)
def test_dist_in_addopts(pytester: pytest.Pytester):
def test_dist_in_addopts(pytester: pytest.Pytester) -> None:
"""Users can set a default distribution in the configuration file (#789)."""
pytester.makepyfile(
"""

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import shutil
from typing import Callable
from typing import Generator
import execnet
import pytest
@ -10,12 +12,14 @@ pytest_plugins = "pytester"
@pytest.fixture(autouse=True)
def _divert_atexit(request, monkeypatch: pytest.MonkeyPatch):
def _divert_atexit(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]:
import atexit
finalizers = []
def fake_register(func, *args, **kwargs):
def fake_register(
func: Callable[..., object], *args: object, **kwargs: object
) -> None:
finalizers.append((func, args, kwargs))
monkeypatch.setattr(atexit, "register", fake_register)
@ -27,7 +31,7 @@ def _divert_atexit(request, monkeypatch: pytest.MonkeyPatch):
func(*args, **kwargs)
def pytest_addoption(parser) -> None:
def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
"--gx",
action="append",
@ -37,16 +41,16 @@ def pytest_addoption(parser) -> None:
@pytest.fixture
def specssh(request) -> str:
def specssh(request: pytest.FixtureRequest) -> str:
return getspecssh(request.config)
# configuration information for tests
def getgspecs(config) -> list[execnet.XSpec]:
def getgspecs(config: pytest.Config) -> list[execnet.XSpec]:
return [execnet.XSpec(spec) for spec in config.getvalueorskip("gspecs")]
def getspecssh(config) -> str: # type: ignore[return]
def getspecssh(config: pytest.Config) -> str:
xspecs = getgspecs(config)
for spec in xspecs:
if spec.ssh:
@ -56,7 +60,7 @@ def getspecssh(config) -> str: # type: ignore[return]
pytest.skip("need '--gx ssh=...'")
def getsocketspec(config) -> execnet.XSpec:
def getsocketspec(config: pytest.Config) -> execnet.XSpec:
xspecs = getgspecs(config)
for spec in xspecs:
if spec.socket:

View File

@ -1,6 +1,9 @@
from __future__ import annotations
from typing import Any
from typing import cast
from typing import Sequence
from typing import TYPE_CHECKING
import execnet
import pytest
@ -13,29 +16,38 @@ from xdist.report import report_collection_diff
from xdist.scheduler import EachScheduling
from xdist.scheduler import LoadScheduling
from xdist.scheduler import WorkStealingScheduling
from xdist.workermanage import WorkerController
class MockGateway:
if TYPE_CHECKING:
BaseOfMockGateway = execnet.Gateway
BaseOfMockNode = WorkerController
else:
BaseOfMockGateway = object
BaseOfMockNode = object
class MockGateway(BaseOfMockGateway):
def __init__(self) -> None:
self._count = 0
self.id = str(self._count)
self._count += 1
class MockNode:
class MockNode(BaseOfMockNode):
def __init__(self) -> None:
self.sent = [] # type: ignore[var-annotated]
self.stolen = [] # type: ignore[var-annotated]
self.sent: list[int | str] = []
self.stolen: list[int] = []
self.gateway = MockGateway()
self._shutdown = False
def send_runtest_some(self, indices) -> None:
def send_runtest_some(self, indices: Sequence[int]) -> None:
self.sent.extend(indices)
def send_runtest_all(self) -> None:
self.sent.append("ALL")
def send_steal(self, indices) -> None:
def send_steal(self, indices: Sequence[int]) -> None:
self.stolen.extend(indices)
def shutdown(self) -> None:
@ -48,10 +60,9 @@ class MockNode:
class TestEachScheduling:
def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None:
node1 = MockNode()
node2 = MockNode()
config = pytester.parseconfig("--tx=2*popen")
sched = EachScheduling(config)
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
collection = ["a.py::test_1"]
@ -59,7 +70,7 @@ class TestEachScheduling:
sched.add_node_collection(node1, collection)
assert not sched.collection_is_completed
sched.add_node_collection(node2, collection)
assert sched.collection_is_completed
assert bool(sched.collection_is_completed)
assert sched.node2collection[node1] == collection
assert sched.node2collection[node2] == collection
sched.schedule()
@ -72,14 +83,14 @@ class TestEachScheduling:
assert sched.tests_finished
def test_schedule_remove_node(self, pytester: pytest.Pytester) -> None:
node1 = MockNode()
config = pytester.parseconfig("--tx=popen")
sched = EachScheduling(config)
node1 = MockNode()
sched.add_node(node1)
collection = ["a.py::test_1"]
assert not sched.collection_is_completed
sched.add_node_collection(node1, collection)
assert sched.collection_is_completed
assert bool(sched.collection_is_completed)
assert sched.node2collection[node1] == collection
sched.schedule()
assert sched.tests_finished
@ -93,15 +104,15 @@ class TestLoadScheduling:
def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=2*popen")
sched = LoadScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2 = sched.nodes
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
collection = ["a.py::test_1", "a.py::test_2"]
assert not sched.collection_is_completed
sched.add_node_collection(node1, collection)
assert not sched.collection_is_completed
sched.add_node_collection(node2, collection)
assert sched.collection_is_completed
assert bool(sched.collection_is_completed)
assert sched.node2collection[node1] == collection
assert sched.node2collection[node2] == collection
sched.schedule()
@ -111,15 +122,17 @@ class TestLoadScheduling:
assert len(node2.sent) == 1
assert node1.sent == [0]
assert node2.sent == [1]
sched.mark_test_complete(node1, node1.sent[0])
sent10 = node1.sent[0]
assert isinstance(sent10, int)
sched.mark_test_complete(node1, sent10)
assert sched.tests_finished
def test_schedule_batch_size(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=2*popen")
sched = LoadScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2 = sched.nodes
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
col = ["xyz"] * 6
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -144,9 +157,9 @@ class TestLoadScheduling:
def test_schedule_maxchunk_none(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=2*popen")
sched = LoadScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2 = sched.nodes
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
col = [f"test{i}" for i in range(16)]
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -172,9 +185,9 @@ class TestLoadScheduling:
def test_schedule_maxchunk_1(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=2*popen", "--maxschedchunk=1")
sched = LoadScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2 = sched.nodes
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
col = [f"test{i}" for i in range(16)]
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -186,7 +199,9 @@ class TestLoadScheduling:
assert sched.node2pending[node2] == node2.sent
for complete_index, first_pending in enumerate(range(5, 16)):
sched.mark_test_complete(node1, node1.sent[complete_index])
sent_index = node1.sent[complete_index]
assert isinstance(sent_index, int)
sched.mark_test_complete(node1, sent_index)
assert node1.sent == [0, 1, *range(4, first_pending)]
assert node2.sent == [2, 3]
assert sched.pending == list(range(first_pending, 16))
@ -194,10 +209,10 @@ class TestLoadScheduling:
def test_schedule_fewer_tests_than_nodes(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=3*popen")
sched = LoadScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2, node3 = sched.nodes
node1, node2, node3 = MockNode(), MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
sched.add_node(node3)
col = ["xyz"] * 2
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -215,10 +230,10 @@ class TestLoadScheduling:
) -> None:
config = pytester.parseconfig("--tx=3*popen")
sched = LoadScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2, node3 = sched.nodes
node1, node2, node3 = MockNode(), MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
sched.add_node(node3)
col = ["xyz"] * 5
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -232,9 +247,9 @@ class TestLoadScheduling:
assert not sched.pending
def test_add_remove_node(self, pytester: pytest.Pytester) -> None:
node = MockNode()
config = pytester.parseconfig("--tx=popen")
sched = LoadScheduling(config)
node = MockNode()
sched.add_node(node)
collection = ["test_file.py::test_func"]
sched.add_node_collection(node, collection)
@ -253,18 +268,17 @@ class TestLoadScheduling:
class CollectHook:
"""Dummy hook that stores collection reports."""
def __init__(self):
self.reports = []
def __init__(self) -> None:
self.reports: list[pytest.CollectReport] = []
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: pytest.CollectReport) -> None:
self.reports.append(report)
collect_hook = CollectHook()
config = pytester.parseconfig("--tx=2*popen")
config.pluginmanager.register(collect_hook, "collect_hook")
node1 = MockNode()
node2 = MockNode()
sched = LoadScheduling(config)
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
sched.add_node_collection(node1, ["a.py::test_1"])
@ -272,6 +286,7 @@ class TestLoadScheduling:
sched.schedule()
assert len(collect_hook.reports) == 1
rep = collect_hook.reports[0]
assert isinstance(rep.longrepr, str)
assert "Different tests were collected between" in rep.longrepr
@ -279,15 +294,15 @@ class TestWorkStealingScheduling:
def test_ideal_case(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=2*popen")
sched = WorkStealingScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2 = sched.nodes
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
collection = [f"test_workstealing.py::test_{i}" for i in range(16)]
assert not sched.collection_is_completed
sched.add_node_collection(node1, collection)
assert not sched.collection_is_completed
sched.add_node_collection(node2, collection)
assert sched.collection_is_completed
assert bool(sched.collection_is_completed)
assert sched.node2collection[node1] == collection
assert sched.node2collection[node2] == collection
sched.schedule()
@ -296,18 +311,20 @@ class TestWorkStealingScheduling:
assert node1.sent == list(range(8))
assert node2.sent == list(range(8, 16))
for i in range(8):
sched.mark_test_complete(node1, node1.sent[i])
sched.mark_test_complete(node2, node2.sent[i])
assert sched.tests_finished
sent1, sent2 = node1.sent[i], node2.sent[i]
assert isinstance(sent1, int) and isinstance(sent2, int)
sched.mark_test_complete(node1, sent1)
sched.mark_test_complete(node2, sent2)
assert bool(sched.tests_finished)
assert node1.stolen == []
assert node2.stolen == []
def test_stealing(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=2*popen")
sched = WorkStealingScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2 = sched.nodes
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
collection = [f"test_workstealing.py::test_{i}" for i in range(16)]
sched.add_node_collection(node1, collection)
sched.add_node_collection(node2, collection)
@ -316,11 +333,15 @@ class TestWorkStealingScheduling:
assert node1.sent == list(range(8))
assert node2.sent == list(range(8, 16))
for i in range(8):
sched.mark_test_complete(node1, node1.sent[i])
sent = node1.sent[i]
assert isinstance(sent, int)
sched.mark_test_complete(node1, sent)
assert node2.stolen == list(range(12, 16))
sched.remove_pending_tests_from_node(node2, node2.stolen)
for i in range(4):
sched.mark_test_complete(node2, node2.sent[i])
sent = node2.sent[i]
assert isinstance(sent, int)
sched.mark_test_complete(node2, sent)
assert node1.stolen == [14, 15]
sched.remove_pending_tests_from_node(node1, node1.stolen)
sched.mark_test_complete(node1, 12)
@ -355,10 +376,10 @@ class TestWorkStealingScheduling:
def test_schedule_fewer_tests_than_nodes(self, pytester: pytest.Pytester) -> None:
config = pytester.parseconfig("--tx=3*popen")
sched = WorkStealingScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2, node3 = sched.nodes
node1, node2, node3 = MockNode(), MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
sched.add_node(node3)
col = ["xyz"] * 2
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -378,10 +399,10 @@ class TestWorkStealingScheduling:
) -> None:
config = pytester.parseconfig("--tx=3*popen")
sched = WorkStealingScheduling(config)
sched.add_node(MockNode())
sched.add_node(MockNode())
sched.add_node(MockNode())
node1, node2, node3 = sched.nodes
node1, node2, node3 = MockNode(), MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
sched.add_node(node3)
col = ["xyz"] * 5
sched.add_node_collection(node1, col)
sched.add_node_collection(node2, col)
@ -392,11 +413,19 @@ class TestWorkStealingScheduling:
assert node3.sent == [3, 4]
assert not sched.pending
assert not sched.tests_finished
sched.mark_test_complete(node1, node1.sent[0])
sched.mark_test_complete(node2, node2.sent[0])
sched.mark_test_complete(node3, node3.sent[0])
sched.mark_test_complete(node3, node3.sent[1])
assert sched.tests_finished
sent10 = node1.sent[0]
assert isinstance(sent10, int)
sent20 = node2.sent[0]
assert isinstance(sent20, int)
sent30 = node3.sent[0]
assert isinstance(sent30, int)
sent31 = node3.sent[1]
assert isinstance(sent31, int)
sched.mark_test_complete(node1, sent10)
sched.mark_test_complete(node2, sent20)
sched.mark_test_complete(node3, sent30)
sched.mark_test_complete(node3, sent31)
assert bool(sched.tests_finished)
assert node1.stolen == []
assert node2.stolen == []
assert node3.stolen == []
@ -416,18 +445,17 @@ class TestWorkStealingScheduling:
def test_different_tests_collected(self, pytester: pytest.Pytester) -> None:
class CollectHook:
def __init__(self):
self.reports = []
def __init__(self) -> None:
self.reports: list[pytest.CollectReport] = []
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: pytest.CollectReport) -> None:
self.reports.append(report)
collect_hook = CollectHook()
config = pytester.parseconfig("--tx=2*popen")
config.pluginmanager.register(collect_hook, "collect_hook")
node1 = MockNode()
node2 = MockNode()
sched = WorkStealingScheduling(config)
node1, node2 = MockNode(), MockNode()
sched.add_node(node1)
sched.add_node(node2)
sched.add_node_collection(node1, ["a.py::test_1"])
@ -435,12 +463,13 @@ class TestWorkStealingScheduling:
sched.schedule()
assert len(collect_hook.reports) == 1
rep = collect_hook.reports[0]
assert isinstance(rep.longrepr, str)
assert "Different tests were collected between" in rep.longrepr
class TestDistReporter:
@pytest.mark.xfail
def test_rsync_printing(self, pytester: pytest.Pytester, linecomp) -> None:
def test_rsync_printing(self, pytester: pytest.Pytester, linecomp: Any) -> None:
config = pytester.parseconfig()
from _pytest.terminal import TerminalReporter
@ -473,15 +502,17 @@ class TestDistReporter:
def test_report_collection_diff_equal() -> None:
"""Test reporting of equal collections."""
from_collection = to_collection = ["aaa", "bbb", "ccc"]
assert report_collection_diff(from_collection, to_collection, 1, 2) is None
assert report_collection_diff(from_collection, to_collection, "1", "2") is None
def test_default_max_worker_restart() -> None:
class config:
class MockConfig:
class option:
maxworkerrestart: str | None = None
numprocesses: int = 0
config = cast(pytest.Config, MockConfig)
assert get_default_max_worker_restart(config) is None
config.option.numprocesses = 2

View File

@ -143,7 +143,7 @@ class TestRemoteControl:
control = RemoteControl(modcol.config)
control.loop_once()
assert control.failures
modcol_path = modcol.path # type:ignore[attr-defined]
modcol_path = modcol.path
modcol_path.write_text(
textwrap.dedent(
@ -173,7 +173,7 @@ class TestRemoteControl:
"""
)
)
parent = modcol.path.parent.parent # type: ignore[attr-defined]
parent = modcol.path.parent.parent
monkeypatch.chdir(parent)
modcol.config.args = [
str(Path(x).relative_to(parent)) for x in modcol.config.args
@ -332,7 +332,7 @@ class TestLooponFailing:
remotecontrol = RemoteControl(modcol.config)
orig_runsession = remotecontrol.runsession
def runsession_dups():
def runsession_dups() -> tuple[list[str], list[str], bool]:
# twisted.trial test cases may report multiple errors.
failures, reports, collection_failed = orig_runsession()
print(failures)

View File

@ -10,7 +10,7 @@ from xdist.workermanage import NodeManager
@pytest.fixture
def monkeypatch_3_cpus(monkeypatch: pytest.MonkeyPatch):
def monkeypatch_3_cpus(monkeypatch: pytest.MonkeyPatch) -> None:
"""Make pytest-xdist believe the system has 3 CPUs."""
# block import
monkeypatch.setitem(sys.modules, "psutil", None)
@ -128,7 +128,7 @@ def test_auto_detect_cpus_psutil(
def test_auto_detect_cpus_os(
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None
) -> None:
from xdist.plugin import pytest_cmdline_main as check_options
@ -189,7 +189,7 @@ def test_hook_auto_num_workers_arg(
def test_hook_auto_num_workers_none(
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None
) -> None:
# Returning None from a hook to skip it is pytest behavior,
# but we document it so let's test it.
@ -231,7 +231,7 @@ def test_envvar_auto_num_workers(
def test_envvar_auto_num_workers_warn(
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None
) -> None:
from xdist.plugin import pytest_cmdline_main as check_options
@ -244,7 +244,7 @@ def test_envvar_auto_num_workers_warn(
def test_auto_num_workers_hook_overrides_envvar(
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus
pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None
) -> None:
from xdist.plugin import pytest_cmdline_main as check_options

View File

@ -1,36 +1,46 @@
from __future__ import annotations
import marshal
import pprint
from queue import Queue
import sys
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import Union
import uuid
import execnet
import pytest
from xdist.workermanage import NodeManager
from xdist.workermanage import WorkerController
WAIT_TIMEOUT = 10.0
def check_marshallable(d):
def check_marshallable(d: object) -> None:
try:
marshal.dumps(d)
marshal.dumps(d) # type: ignore[arg-type]
except ValueError as e:
pprint.pprint(d)
raise ValueError("not marshallable") from e
class EventCall:
def __init__(self, eventcall):
def __init__(self, eventcall: tuple[str, dict[str, Any]]) -> None:
self.name, self.kwargs = eventcall
def __str__(self):
def __str__(self) -> str:
return f"<EventCall {self.name}(**{self.kwargs})>"
class WorkerSetup:
def __init__(self, request, pytester: pytest.Pytester) -> None:
def __init__(
self, request: pytest.FixtureRequest, pytester: pytest.Pytester
) -> None:
self.request = request
self.pytester = pytester
self.use_callback = False
@ -47,11 +57,18 @@ class WorkerSetup:
testrunuid = uuid.uuid4().hex
specs = [0, 1]
self.slp = WorkerController(DummyMananger, self.gateway, config, putevent)
nodemanager = cast(NodeManager, DummyMananger)
self.slp = WorkerController(
nodemanager=nodemanager,
gateway=self.gateway,
config=config,
putevent=putevent, # type: ignore[arg-type]
)
self.request.addfinalizer(self.slp.ensure_teardown)
self.slp.setup()
def popevent(self, name=None):
def popevent(self, name: str | None = None) -> EventCall:
while 1:
if self.use_callback:
data = self.events.get(timeout=WAIT_TIMEOUT)
@ -62,27 +79,33 @@ class WorkerSetup:
return ev
print(f"skipping {ev}")
def sendcommand(self, name, **kwargs):
def sendcommand(self, name: str, **kwargs: Any) -> None:
self.slp.sendcommand(name, **kwargs)
@pytest.fixture
def worker(request, pytester: pytest.Pytester) -> WorkerSetup:
def worker(request: pytest.FixtureRequest, pytester: pytest.Pytester) -> WorkerSetup:
return WorkerSetup(request, pytester)
class TestWorkerInteractor:
UnserializerReport = Callable[
[Dict[str, Any]], Union[pytest.CollectReport, pytest.TestReport]
]
@pytest.fixture
def unserialize_report(self, pytestconfig):
def unserialize(data):
return pytestconfig.hook.pytest_report_from_serializable(
def unserialize_report(self, pytestconfig: pytest.Config) -> UnserializerReport:
def unserialize(
data: dict[str, Any],
) -> pytest.CollectReport | pytest.TestReport:
return pytestconfig.hook.pytest_report_from_serializable( # type: ignore[no-any-return]
config=pytestconfig, data=data
)
return unserialize
def test_basic_collect_and_runtests(
self, worker: WorkerSetup, unserialize_report
self, worker: WorkerSetup, unserialize_report: UnserializerReport
) -> None:
worker.pytester.makepyfile(
"""
@ -115,7 +138,9 @@ class TestWorkerInteractor:
ev = worker.popevent("workerfinished")
assert "workeroutput" in ev.kwargs
def test_remote_collect_skip(self, worker: WorkerSetup, unserialize_report) -> None:
def test_remote_collect_skip(
self, worker: WorkerSetup, unserialize_report: UnserializerReport
) -> None:
worker.pytester.makepyfile(
"""
import pytest
@ -129,11 +154,14 @@ class TestWorkerInteractor:
assert ev.name == "collectreport"
rep = unserialize_report(ev.kwargs["data"])
assert rep.skipped
assert isinstance(rep.longrepr, tuple)
assert rep.longrepr[2] == "Skipped: hello"
ev = worker.popevent("collectionfinish")
assert not ev.kwargs["ids"]
def test_remote_collect_fail(self, worker: WorkerSetup, unserialize_report) -> None:
def test_remote_collect_fail(
self, worker: WorkerSetup, unserialize_report: UnserializerReport
) -> None:
worker.pytester.makepyfile("""aasd qwe""")
worker.setup()
ev = worker.popevent("collectionstart")
@ -145,7 +173,9 @@ class TestWorkerInteractor:
ev = worker.popevent("collectionfinish")
assert not ev.kwargs["ids"]
def test_runtests_all(self, worker: WorkerSetup, unserialize_report) -> None:
def test_runtests_all(
self, worker: WorkerSetup, unserialize_report: UnserializerReport
) -> None:
worker.pytester.makepyfile(
"""
def test_func(): pass
@ -205,13 +235,15 @@ class TestWorkerInteractor:
) -> None:
worker.use_callback = True
worker.setup()
worker.slp.process_from_remote(("<nonono>", ()))
worker.slp.process_from_remote(("<nonono>", {}))
out, err = capsys.readouterr()
assert "INTERNALERROR> ValueError: unknown event: <nonono>" in out
ev = worker.popevent()
assert ev.name == "errordown"
def test_steal_work(self, worker: WorkerSetup, unserialize_report) -> None:
def test_steal_work(
self, worker: WorkerSetup, unserialize_report: UnserializerReport
) -> None:
worker.pytester.makepyfile(
"""
import time
@ -262,7 +294,9 @@ class TestWorkerInteractor:
ev = worker.popevent("workerfinished")
assert "workeroutput" in ev.kwargs
def test_steal_empty_queue(self, worker: WorkerSetup, unserialize_report) -> None:
def test_steal_empty_queue(
self, worker: WorkerSetup, unserialize_report: UnserializerReport
) -> None:
worker.pytester.makepyfile(
"""
def test_func(): pass

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pathlib import Path
import shutil
import textwrap
@ -19,13 +21,15 @@ pytest_plugins = "pytester"
@pytest.fixture
def hookrecorder(request, config, pytester: pytest.Pytester):
def hookrecorder(
config: pytest.Config, pytester: pytest.Pytester
) -> pytest.HookRecorder:
hookrecorder = pytester.make_hook_recorder(config.pluginmanager)
return hookrecorder
@pytest.fixture
def config(pytester: pytest.Pytester):
def config(pytester: pytest.Pytester) -> pytest.Config:
return pytester.parseconfig()
@ -44,24 +48,23 @@ def dest(tmp_path: Path) -> Path:
@pytest.fixture
def workercontroller(monkeypatch: pytest.MonkeyPatch):
def workercontroller(monkeypatch: pytest.MonkeyPatch) -> None:
class MockController:
def __init__(self, *args):
def __init__(self, *args: object) -> None:
pass
def setup(self):
def setup(self) -> None:
pass
monkeypatch.setattr(workermanage, "WorkerController", MockController)
return MockController
class TestNodeManagerPopen:
def test_popen_no_default_chdir(self, config) -> None:
def test_popen_no_default_chdir(self, config: pytest.Config) -> None:
gm = NodeManager(config, ["popen"])
assert gm.specs[0].chdir is None
def test_default_chdir(self, config) -> None:
def test_default_chdir(self, config: pytest.Config) -> None:
specs = ["ssh=noco", "socket=xyz"]
for spec in NodeManager(config, specs).specs:
assert spec.chdir == "pyexecnetcache"
@ -69,10 +72,13 @@ class TestNodeManagerPopen:
assert spec.chdir == "abc"
def test_popen_makegateway_events(
self, config, hookrecorder, workercontroller
self,
config: pytest.Config,
hookrecorder: pytest.HookRecorder,
workercontroller: None,
) -> None:
hm = NodeManager(config, ["popen"] * 2)
hm.setup_nodes(None)
hm.setup_nodes(None) # type: ignore[arg-type]
call = hookrecorder.popcall("pytest_xdist_setupnodes")
assert len(call.specs) == 2
@ -86,20 +92,24 @@ class TestNodeManagerPopen:
assert not len(hm.group)
def test_popens_rsync(
self, config, source: Path, dest: Path, workercontroller
self,
config: pytest.Config,
source: Path,
dest: Path,
workercontroller: None,
) -> None:
hm = NodeManager(config, ["popen"] * 2)
hm.setup_nodes(None)
hm.setup_nodes(None) # type: ignore[arg-type]
assert len(hm.group) == 2
for gw in hm.group:
class pseudoexec:
args = [] # type: ignore[var-annotated]
def __init__(self, *args):
def __init__(self, *args: object) -> None:
self.args.extend(args)
def waitclose(self):
def waitclose(self) -> None:
pass
gw.remote_exec = pseudoexec # type: ignore[assignment]
@ -112,10 +122,10 @@ class TestNodeManagerPopen:
assert "sys.path.insert" in gw.remote_exec.args[0] # type: ignore[attr-defined]
def test_rsync_popen_with_path(
self, config, source: Path, dest: Path, workercontroller
self, config: pytest.Config, source: Path, dest: Path, workercontroller: None
) -> None:
hm = NodeManager(config, ["popen//chdir=%s" % dest] * 1)
hm.setup_nodes(None)
hm.setup_nodes(None) # type: ignore[arg-type]
source.joinpath("dir1", "dir2").mkdir(parents=True)
source.joinpath("dir1", "dir2", "hello").touch()
notifications = []
@ -131,15 +141,15 @@ class TestNodeManagerPopen:
def test_rsync_same_popen_twice(
self,
config,
config: pytest.Config,
source: Path,
dest: Path,
hookrecorder,
workercontroller,
hookrecorder: pytest.HookRecorder,
workercontroller: None,
) -> None:
hm = NodeManager(config, ["popen//chdir=%s" % dest] * 2)
hm.roots = []
hm.setup_nodes(None)
hm.setup_nodes(None) # type: ignore[arg-type]
source.joinpath("dir1", "dir2").mkdir(parents=True)
source.joinpath("dir1", "dir2", "hello").touch()
gw = hm.group[0]
@ -200,7 +210,11 @@ class TestNodeManager:
assert p.joinpath("dir1", "file1").check()
def test_popen_rsync_subdir(
self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller
self,
pytester: pytest.Pytester,
source: Path,
dest: Path,
workercontroller: None,
) -> None:
dir1 = source / "dir1"
dir1.mkdir()
@ -214,7 +228,8 @@ class TestNodeManager:
"--tx", "popen//chdir=%s" % dest, "--rsyncdir", rsyncroot, source
)
)
nodemanager.setup_nodes(None) # calls .rsync_roots()
# calls .rsync_roots()
nodemanager.setup_nodes(None) # type: ignore[arg-type]
if rsyncroot == source:
dest = dest.joinpath("source")
assert dest.joinpath("dir1").exists()
@ -223,14 +238,19 @@ class TestNodeManager:
nodemanager.teardown_nodes()
@pytest.mark.parametrize(
"flag, expects_report", [("-q", False), ("", False), ("-v", True)]
["flag", "expects_report"],
[
("-q", False),
("", False),
("-v", True),
],
)
def test_rsync_report(
self,
pytester: pytest.Pytester,
source: Path,
dest: Path,
workercontroller,
workercontroller: None,
capsys: pytest.CaptureFixture[str],
flag: str,
expects_report: bool,
@ -241,7 +261,8 @@ class TestNodeManager:
if flag:
args.append(flag)
nodemanager = NodeManager(pytester.parseconfig(*args))
nodemanager.setup_nodes(None) # calls .rsync_roots()
# calls .rsync_roots()
nodemanager.setup_nodes(None) # type: ignore[arg-type]
out, _ = capsys.readouterr()
if expects_report:
assert "<= pytest/__init__.py" in out
@ -249,7 +270,11 @@ class TestNodeManager:
assert "<= pytest/__init__.py" not in out
def test_init_rsync_roots(
self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller
self,
pytester: pytest.Pytester,
source: Path,
dest: Path,
workercontroller: None,
) -> None:
dir2 = source.joinpath("dir1", "dir2")
dir2.mkdir(parents=True)
@ -267,13 +292,18 @@ class TestNodeManager:
)
config = pytester.parseconfig(source)
nodemanager = NodeManager(config, ["popen//chdir=%s" % dest])
nodemanager.setup_nodes(None) # calls .rsync_roots()
# calls .rsync_roots()
nodemanager.setup_nodes(None) # type: ignore[arg-type]
assert dest.joinpath("dir2").exists()
assert not dest.joinpath("dir1").exists()
assert not dest.joinpath("bogus").exists()
def test_rsyncignore(
self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller
self,
pytester: pytest.Pytester,
source: Path,
dest: Path,
workercontroller: None,
) -> None:
dir2 = source.joinpath("dir1", "dir2")
dir2.mkdir(parents=True)
@ -297,7 +327,8 @@ class TestNodeManager:
config = pytester.parseconfig(source)
config.option.rsyncignore = ["bar"]
nodemanager = NodeManager(config, ["popen//chdir=%s" % dest])
nodemanager.setup_nodes(None) # calls .rsync_roots()
# calls .rsync_roots()
nodemanager.setup_nodes(None) # type: ignore[arg-type]
assert dest.joinpath("dir1").exists()
assert not dest.joinpath("dir1", "dir2").exists()
assert dest.joinpath("dir5", "file").exists()
@ -306,14 +337,19 @@ class TestNodeManager:
assert not dest.joinpath("bar").exists()
def test_optimise_popen(
self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller
self,
pytester: pytest.Pytester,
source: Path,
dest: Path,
workercontroller: None,
) -> None:
specs = ["popen"] * 3
source.joinpath("conftest.py").write_text("rsyncdirs = ['a']")
source.joinpath("a").mkdir()
config = pytester.parseconfig(source)
nodemanager = NodeManager(config, specs)
nodemanager.setup_nodes(None) # calls .rysnc_roots()
# calls .rysnc_roots()
nodemanager.setup_nodes(None) # type: ignore[arg-type]
for gwspec in nodemanager.specs:
assert gwspec._samefilesystem()
assert not gwspec.chdir
@ -349,7 +385,7 @@ class MyWarning(UserWarning):
),
],
)
def test_unserialize_warning_msg(w_cls):
def test_unserialize_warning_msg(w_cls: type[Warning] | str) -> None:
"""Test that warning serialization process works well."""
# Create a test warning message
with pytest.warns(UserWarning) as w:
@ -390,7 +426,7 @@ class MyWarningUnknown(UserWarning):
__module__ = "unknown"
def test_warning_serialization_tweaked_module():
def test_warning_serialization_tweaked_module() -> None:
"""Test for GH#404."""
# Create a test warning message
with pytest.warns(UserWarning) as w:

View File

@ -5,5 +5,5 @@ class MyWarning2(UserWarning):
pass
def generate_warning():
def generate_warning() -> None:
warnings.warn(MyWarning2("hello"))