From 81679251ef0972bda16e601a6a014a9dc5d304a3 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 11 Jan 2026 22:47:17 +0000 Subject: [PATCH] Use proper fixed format for parallel type checking IPC --- mypy/build.py | 244 +++++++++++++++++++++++++++++++----- mypy/build_worker/worker.py | 44 +++---- mypy/ipc.py | 35 +++--- 3 files changed, 254 insertions(+), 69 deletions(-) diff --git a/mypy/build.py b/mypy/build.py index 9b6e4d54d1b1..d4038cd5c0aa 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -40,16 +40,40 @@ TypedDict, ) -from librt.internal import cache_version +from librt.internal import ( + cache_version, + read_bool, + read_int as read_int_bare, + read_str as read_str_bare, + read_tag, + write_bool, + write_int as write_int_bare, + write_str as write_str_bare, + write_tag, +) import mypy.semanal_main from mypy.cache import ( CACHE_VERSION, + DICT_STR_GEN, + LITERAL_NONE, CacheMeta, ReadBuffer, SerializedError, + Tag, WriteBuffer, - write_json, + read_int, + read_int_list, + read_int_opt, + read_str, + read_str_list, + read_str_opt, + write_int, + write_int_list, + write_int_opt, + write_str, + write_str_list, + write_str_opt, ) from mypy.checker import TypeChecker from mypy.defaults import ( @@ -62,7 +86,7 @@ from mypy.errors import CompileError, ErrorInfo, Errors, ErrorTuple, report_internal_error from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.indirection import TypeIndirectionVisitor -from mypy.ipc import BadStatus, IPCClient, read_status, ready_to_read, receive, send +from mypy.ipc import BadStatus, IPCClient, IPCMessage, read_status, ready_to_read, receive, send from mypy.messages import MessageBuilder from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable from mypy.partially_defined import PossiblyUndefinedVariableVisitor @@ -310,7 +334,10 @@ def default_flush_errors( WorkerClient(f".mypy_worker.{idx}.json", options_data, worker_env or os.environ) for idx in range(options.num_workers) ] - sources_data = sources_to_bytes(sources) + sources_message = SourcesDataMessage(sources=sources) + buf = WriteBuffer() + sources_message.write(buf) + sources_data = buf.getvalue() for worker in workers: # Start loading graph in each worker as soon as it is up. worker.connect() @@ -342,7 +369,7 @@ def default_flush_errors( finally: for worker in workers: try: - send(worker.conn, {"final": True}) + send(worker.conn, SccRequestMessage(scc_id=None)) except OSError: pass for worker in workers: @@ -1049,7 +1076,7 @@ def submit_to_workers(self, sccs: list[SCC] | None = None) -> None: while self.scc_queue and self.free_workers: idx = self.free_workers.pop() _, _, scc = heappop(self.scc_queue) - send(self.workers[idx].conn, {"scc_id": scc.id}) + send(self.workers[idx].conn, SccRequestMessage(scc_id=scc.id)) def wait_for_done( self, graph: Graph @@ -1077,15 +1104,13 @@ def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, l done_sccs = [] results = {} for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT): - data = receive(self.workers[idx].conn) + data = SccResponseMessage.read(receive(self.workers[idx].conn)) self.free_workers.add(idx) - scc_id = data["scc_id"] - if "blocker" in data: - blocker = data["blocker"] - raise CompileError( - blocker["messages"], blocker["use_stdout"], blocker["module_with_blocker"] - ) - results.update({k: tuple(v) for k, v in data["result"].items()}) + scc_id = data.scc_id + if data.blocker is not None: + raise data.blocker + assert data.result is not None + results.update(data.result) done_sccs.append(self.scc_by_id[scc_id]) self.submit_to_workers() # advance after some workers are free. return ( @@ -3558,14 +3583,15 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: manager.top_order = [scc.id for scc in sccs] # Broadcast SCC structure to the parallel workers, since they don't compute it. - sccs_data = sccs_to_bytes(sccs) + sccs_message = SccsDataMessage(sccs=sccs) + buf = WriteBuffer() + sccs_message.write(buf) + sccs_data = buf.getvalue() for worker in manager.workers: - data = receive(worker.conn) - assert data["status"] == "ok" + AckMessage.read(receive(worker.conn)) worker.conn.write_bytes(sccs_data) for worker in manager.workers: - data = receive(worker.conn) - assert data["status"] == "ok" + AckMessage.read(receive(worker.conn)) manager.free_workers = set(range(manager.options.num_workers)) @@ -3944,20 +3970,6 @@ def write_undocumented_ref_info( metastore.write(ref_info_file, json_dumps(deps_json)) -def sources_to_bytes(sources: list[BuildSource]) -> bytes: - source_tuples = [(s.path, s.module, s.text, s.base_dir, s.followed) for s in sources] - buf = WriteBuffer() - write_json(buf, {"sources": source_tuples}) - return buf.getvalue() - - -def sccs_to_bytes(sccs: list[SCC]) -> bytes: - scc_tuples = [(list(scc.mod_ids), scc.id, list(scc.deps)) for scc in sccs] - buf = WriteBuffer() - write_json(buf, {"sccs": scc_tuples}) - return buf.getvalue() - - def serialize_codes(errs: list[ErrorTuple]) -> list[SerializedError]: return [ (path, line, column, end_line, end_column, severity, message, code.code if code else None) @@ -3979,3 +3991,169 @@ def deserialize_codes(errs: list[SerializedError]) -> list[ErrorTuple]: ) for path, line, column, end_line, end_column, severity, message, code in errs ] + + +# The IPC message classes and tags for communication with build workers are +# in this file to avoid import cycles. +# Note that we use a more compact fixed serialization format than in cache.py. +# This is because the messages don't need to read by a generic tool, nor there +# is any need for backwards compatibility. We still reuse some elements from +# cache.py for convenience, and also some conventions (like using bare ints +# to specify object size). +# Note that we can use tags overlapping with cache.py, since they should never +# appear on the same context. +ACK_MESSAGE: Final[Tag] = 101 +SCC_REQUEST_MESSAGE: Final[Tag] = 102 +SCC_RESPONSE_MESSAGE: Final[Tag] = 103 +SOURCES_DATA_MESSAGE: Final[Tag] = 104 +SCCS_DATA_MESSAGE: Final[Tag] = 105 + + +class AckMessage(IPCMessage): + """An empty message used primarily for synchronization.""" + + @classmethod + def read(cls, buf: ReadBuffer) -> AckMessage: + assert read_tag(buf) == ACK_MESSAGE + return AckMessage() + + def write(self, buf: WriteBuffer) -> None: + write_tag(buf, ACK_MESSAGE) + + +class SccRequestMessage(IPCMessage): + """ + A message representing a request to type check an SCC. + + If scc_id is None, then it means that the coordinator requested a shutdown. + """ + + def __init__(self, *, scc_id: int | None) -> None: + self.scc_id = scc_id + + @classmethod + def read(cls, buf: ReadBuffer) -> SccRequestMessage: + assert read_tag(buf) == SCC_REQUEST_MESSAGE + return SccRequestMessage(scc_id=read_int_opt(buf)) + + def write(self, buf: WriteBuffer) -> None: + write_tag(buf, SCC_REQUEST_MESSAGE) + write_int_opt(buf, self.scc_id) + + +class SccResponseMessage(IPCMessage): + """ + A message representing a result of type checking an SCC. + + Only one of `result` or `blocker` can be non-None. The latter means there was + a blocking error while type checking the SCC. + """ + + def __init__( + self, + *, + scc_id: int, + result: dict[str, tuple[str, list[str]]] | None = None, + blocker: CompileError | None = None, + ) -> None: + if result is not None: + assert blocker is None + if blocker is not None: + assert result is None + self.scc_id = scc_id + self.result = result + self.blocker = blocker + + @classmethod + def read(cls, buf: ReadBuffer) -> SccResponseMessage: + assert read_tag(buf) == SCC_RESPONSE_MESSAGE + scc_id = read_int(buf) + tag = read_tag(buf) + if tag == LITERAL_NONE: + return SccResponseMessage( + scc_id=scc_id, + blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)), + ) + else: + assert tag == DICT_STR_GEN + return SccResponseMessage( + scc_id=scc_id, + result={ + read_str_bare(buf): (read_str(buf), read_str_list(buf)) + for _ in range(read_int_bare(buf)) + }, + ) + + def write(self, buf: WriteBuffer) -> None: + write_tag(buf, SCC_RESPONSE_MESSAGE) + write_int(buf, self.scc_id) + if self.result is None: + assert self.blocker is not None + write_tag(buf, LITERAL_NONE) + write_str_list(buf, self.blocker.messages) + write_bool(buf, self.blocker.use_stdout) + write_str_opt(buf, self.blocker.module_with_blocker) + else: + write_tag(buf, DICT_STR_GEN) + write_int_bare(buf, len(self.result)) + for mod_id in sorted(self.result): + write_str_bare(buf, mod_id) + hex_hash, errs = self.result[mod_id] + write_str(buf, hex_hash) + write_str_list(buf, errs) + + +class SourcesDataMessage(IPCMessage): + """A message wrapping a list of build sources.""" + + def __init__(self, *, sources: list[BuildSource]) -> None: + self.sources = sources + + @classmethod + def read(cls, buf: ReadBuffer) -> SourcesDataMessage: + assert read_tag(buf) == SOURCES_DATA_MESSAGE + sources = [ + BuildSource( + read_str_opt(buf), + read_str_opt(buf), + read_str_opt(buf), + read_str_opt(buf), + read_bool(buf), + ) + for _ in range(read_int_bare(buf)) + ] + return SourcesDataMessage(sources=sources) + + def write(self, buf: WriteBuffer) -> None: + write_tag(buf, SOURCES_DATA_MESSAGE) + write_int_bare(buf, len(self.sources)) + for bs in self.sources: + write_str_opt(buf, bs.path) + write_str_opt(buf, bs.module) + write_str_opt(buf, bs.text) + write_str_opt(buf, bs.base_dir) + write_bool(buf, bs.followed) + + +class SccsDataMessage(IPCMessage): + """A message wrapping the SCC structure computed by the coordinator.""" + + def __init__(self, *, sccs: list[SCC]) -> None: + self.sccs = sccs + + @classmethod + def read(cls, buf: ReadBuffer) -> SccsDataMessage: + assert read_tag(buf) == SCCS_DATA_MESSAGE + sccs = [ + SCC(set(read_str_list(buf)), read_int(buf), read_int_list(buf)) + for _ in range(read_int_bare(buf)) + ] + return SccsDataMessage(sccs=sccs) + + def write(self, buf: WriteBuffer) -> None: + write_tag(buf, SCCS_DATA_MESSAGE) + write_int_bare(buf, len(self.sccs)) + for scc in self.sccs: + write_str_list(buf, sorted(scc.mod_ids)) + write_int(buf, scc.id) + write_int_list(buf, sorted(scc.deps)) diff --git a/mypy/build_worker/worker.py b/mypy/build_worker/worker.py index 3af34411b729..049f5e44256a 100644 --- a/mypy/build_worker/worker.py +++ b/mypy/build_worker/worker.py @@ -5,10 +5,10 @@ * Read (pickled) build options from command line. * Populate status file with pid and socket address. * Receive build sources from coordinator. -* Load graph using the sources, and send "ok" to coordinator. -* Receive SCC structure from coordinator, and ack it with an "ok". +* Load graph using the sources, and send ack to coordinator. +* Receive SCC structure from coordinator, and ack it. * Receive an SCC id from coordinator, process it, and send back the results. -* When prompted by coordinator (with a "final" message), cleanup and shutdown. +* When prompted by coordinator (with a scc_id=None message), cleanup and shutdown. """ from __future__ import annotations @@ -25,7 +25,17 @@ from typing import NamedTuple from mypy import util -from mypy.build import SCC, BuildManager, load_graph, load_plugins, process_stale_scc +from mypy.build import ( + AckMessage, + BuildManager, + SccRequestMessage, + SccResponseMessage, + SccsDataMessage, + SourcesDataMessage, + load_graph, + load_plugins, + process_stale_scc, +) from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT from mypy.errors import CompileError, Errors, report_internal_error from mypy.fscache import FileSystemCache @@ -95,8 +105,7 @@ def main(argv: list[str]) -> None: def serve(server: IPCServer, ctx: ServerContext) -> None: - data = receive(server) - sources = [BuildSource(*st) for st in data["sources"]] + sources = SourcesDataMessage.read(receive(server)).sources manager = setup_worker_manager(sources, ctx) if manager is None: return @@ -117,35 +126,28 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: manager.import_map[id] = set(graph[id].dependencies + graph[id].suppressed) # Notify worker we are done loading graph. - send(server, {"status": "ok"}) - data = receive(server) - sccs = [SCC(set(mod_ids), scc_id, deps) for (mod_ids, scc_id, deps) in data["sccs"]] + send(server, AckMessage()) + sccs = SccsDataMessage.read(receive(server)).sccs manager.scc_by_id = {scc.id: scc for scc in sccs} manager.top_order = [scc.id for scc in sccs] # Notify coordinator we are ready to process SCCs. - send(server, {"status": "ok"}) + send(server, AckMessage()) while True: - data = receive(server) - if "final" in data: + scc_id = SccRequestMessage.read(receive(server)).scc_id + if scc_id is None: manager.dump_stats() break - scc_id = data["scc_id"] scc = manager.scc_by_id[scc_id] t0 = time.time() try: result = process_stale_scc(graph, scc, manager) # We must commit after each SCC, otherwise we break --sqlite-cache. manager.metastore.commit() - except CompileError as e: - blocker = { - "messages": e.messages, - "use_stdout": e.use_stdout, - "module_with_blocker": e.module_with_blocker, - } - send(server, {"scc_id": scc_id, "blocker": blocker}) + except CompileError as blocker: + send(server, SccResponseMessage(scc_id=scc_id, blocker=blocker)) else: - send(server, {"scc_id": scc_id, "result": result}) + send(server, SccResponseMessage(scc_id=scc_id, result=result)) manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1) diff --git a/mypy/ipc.py b/mypy/ipc.py index 3c10acc2b732..f3b250711181 100644 --- a/mypy/ipc.py +++ b/mypy/ipc.py @@ -13,15 +13,15 @@ import shutil import sys import tempfile +from abc import abstractmethod from collections.abc import Callable from select import select from types import TracebackType -from typing import Any, Final +from typing import Final +from typing_extensions import Self from librt.internal import ReadBuffer, WriteBuffer -from mypy.cache import read_json, write_json - if sys.platform == "win32": # This may be private, but it is needed for IPC on Windows, and is basically stable import _winapi @@ -366,29 +366,34 @@ def ready_to_read(conns: list[IPCClient], timeout: float | None = None) -> list[ return [connections.index(r) for r in ready] -# TODO: switch send() and receive() to proper fixed binary format. -def send(connection: IPCBase, data: dict[str, Any]) -> None: +def send(connection: IPCBase, data: IPCMessage) -> None: """Send data to a connection encoded and framed. - The data must be a JSON object. We assume that a single send call is a + The data must be a non-abstract IPCMessage. We assume that a single send call is a single frame to be sent. """ buf = WriteBuffer() - write_json(buf, data) + data.write(buf) connection.write_bytes(buf.getvalue()) -def receive(connection: IPCBase) -> dict[str, Any]: - """Receive single JSON data frame from a connection. +def receive(connection: IPCBase) -> ReadBuffer: + """Receive single encoded IPCMessage frame from a connection. Raise OSError if the data received is not valid. """ bdata = connection.read_bytes() if not bdata: raise OSError("No data received") - try: - buf = ReadBuffer(bdata) - data = read_json(buf) - except Exception as e: - raise OSError("Data received is not valid JSON dict") from e - return data + return ReadBuffer(bdata) + + +class IPCMessage: + @classmethod + @abstractmethod + def read(cls, buf: ReadBuffer) -> Self: + raise NotImplementedError + + @abstractmethod + def write(self, buf: WriteBuffer) -> None: + raise NotImplementedError