Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 211 additions & 33 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,40 @@
TypedDict,
)

from librt.internal import cache_version
from librt.internal import (
cache_version,
read_bool,
read_int as read_int_bare,
read_str as read_str_bare,
read_tag,
write_bool,
write_int as write_int_bare,
write_str as write_str_bare,
write_tag,
)

import mypy.semanal_main
from mypy.cache import (
CACHE_VERSION,
DICT_STR_GEN,
LITERAL_NONE,
CacheMeta,
ReadBuffer,
SerializedError,
Tag,
WriteBuffer,
write_json,
read_int,
read_int_list,
read_int_opt,
read_str,
read_str_list,
read_str_opt,
write_int,
write_int_list,
write_int_opt,
write_str,
write_str_list,
write_str_opt,
)
from mypy.checker import TypeChecker
from mypy.defaults import (
Expand All @@ -62,7 +86,7 @@
from mypy.errors import CompileError, ErrorInfo, Errors, ErrorTuple, report_internal_error
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.indirection import TypeIndirectionVisitor
from mypy.ipc import BadStatus, IPCClient, read_status, ready_to_read, receive, send
from mypy.ipc import BadStatus, IPCClient, IPCMessage, read_status, ready_to_read, receive, send
from mypy.messages import MessageBuilder
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable
from mypy.partially_defined import PossiblyUndefinedVariableVisitor
Expand Down Expand Up @@ -310,7 +334,10 @@ def default_flush_errors(
WorkerClient(f".mypy_worker.{idx}.json", options_data, worker_env or os.environ)
for idx in range(options.num_workers)
]
sources_data = sources_to_bytes(sources)
sources_message = SourcesDataMessage(sources=sources)
buf = WriteBuffer()
sources_message.write(buf)
sources_data = buf.getvalue()
for worker in workers:
# Start loading graph in each worker as soon as it is up.
worker.connect()
Expand Down Expand Up @@ -342,7 +369,7 @@ def default_flush_errors(
finally:
for worker in workers:
try:
send(worker.conn, {"final": True})
send(worker.conn, SccRequestMessage(scc_id=None))
except OSError:
pass
for worker in workers:
Expand Down Expand Up @@ -1049,7 +1076,7 @@ def submit_to_workers(self, sccs: list[SCC] | None = None) -> None:
while self.scc_queue and self.free_workers:
idx = self.free_workers.pop()
_, _, scc = heappop(self.scc_queue)
send(self.workers[idx].conn, {"scc_id": scc.id})
send(self.workers[idx].conn, SccRequestMessage(scc_id=scc.id))

def wait_for_done(
self, graph: Graph
Expand Down Expand Up @@ -1077,15 +1104,13 @@ def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, l
done_sccs = []
results = {}
for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT):
data = receive(self.workers[idx].conn)
data = SccResponseMessage.read(receive(self.workers[idx].conn))
self.free_workers.add(idx)
scc_id = data["scc_id"]
if "blocker" in data:
blocker = data["blocker"]
raise CompileError(
blocker["messages"], blocker["use_stdout"], blocker["module_with_blocker"]
)
results.update({k: tuple(v) for k, v in data["result"].items()})
scc_id = data.scc_id
if data.blocker is not None:
raise data.blocker
assert data.result is not None
results.update(data.result)
done_sccs.append(self.scc_by_id[scc_id])
self.submit_to_workers() # advance after some workers are free.
return (
Expand Down Expand Up @@ -3558,14 +3583,15 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
manager.top_order = [scc.id for scc in sccs]

# Broadcast SCC structure to the parallel workers, since they don't compute it.
sccs_data = sccs_to_bytes(sccs)
sccs_message = SccsDataMessage(sccs=sccs)
buf = WriteBuffer()
sccs_message.write(buf)
sccs_data = buf.getvalue()
for worker in manager.workers:
data = receive(worker.conn)
assert data["status"] == "ok"
AckMessage.read(receive(worker.conn))
worker.conn.write_bytes(sccs_data)
for worker in manager.workers:
data = receive(worker.conn)
assert data["status"] == "ok"
AckMessage.read(receive(worker.conn))

manager.free_workers = set(range(manager.options.num_workers))

Expand Down Expand Up @@ -3944,20 +3970,6 @@ def write_undocumented_ref_info(
metastore.write(ref_info_file, json_dumps(deps_json))


def sources_to_bytes(sources: list[BuildSource]) -> bytes:
source_tuples = [(s.path, s.module, s.text, s.base_dir, s.followed) for s in sources]
buf = WriteBuffer()
write_json(buf, {"sources": source_tuples})
return buf.getvalue()


def sccs_to_bytes(sccs: list[SCC]) -> bytes:
scc_tuples = [(list(scc.mod_ids), scc.id, list(scc.deps)) for scc in sccs]
buf = WriteBuffer()
write_json(buf, {"sccs": scc_tuples})
return buf.getvalue()


def serialize_codes(errs: list[ErrorTuple]) -> list[SerializedError]:
return [
(path, line, column, end_line, end_column, severity, message, code.code if code else None)
Expand All @@ -3979,3 +3991,169 @@ def deserialize_codes(errs: list[SerializedError]) -> list[ErrorTuple]:
)
for path, line, column, end_line, end_column, severity, message, code in errs
]


# The IPC message classes and tags for communication with build workers are
# in this file to avoid import cycles.
# Note that we use a more compact fixed serialization format than in cache.py.
# This is because the messages don't need to read by a generic tool, nor there
# is any need for backwards compatibility. We still reuse some elements from
# cache.py for convenience, and also some conventions (like using bare ints
# to specify object size).
# Note that we can use tags overlapping with cache.py, since they should never
# appear on the same context.
ACK_MESSAGE: Final[Tag] = 101
SCC_REQUEST_MESSAGE: Final[Tag] = 102
SCC_RESPONSE_MESSAGE: Final[Tag] = 103
SOURCES_DATA_MESSAGE: Final[Tag] = 104
SCCS_DATA_MESSAGE: Final[Tag] = 105


class AckMessage(IPCMessage):
"""An empty message used primarily for synchronization."""

@classmethod
def read(cls, buf: ReadBuffer) -> AckMessage:
assert read_tag(buf) == ACK_MESSAGE
return AckMessage()

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, ACK_MESSAGE)


class SccRequestMessage(IPCMessage):
"""
A message representing a request to type check an SCC.

If scc_id is None, then it means that the coordinator requested a shutdown.
"""

def __init__(self, *, scc_id: int | None) -> None:
self.scc_id = scc_id

@classmethod
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
assert read_tag(buf) == SCC_REQUEST_MESSAGE
return SccRequestMessage(scc_id=read_int_opt(buf))

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, SCC_REQUEST_MESSAGE)
write_int_opt(buf, self.scc_id)


class SccResponseMessage(IPCMessage):
"""
A message representing a result of type checking an SCC.

Only one of `result` or `blocker` can be non-None. The latter means there was
a blocking error while type checking the SCC.
"""

def __init__(
self,
*,
scc_id: int,
result: dict[str, tuple[str, list[str]]] | None = None,
blocker: CompileError | None = None,
) -> None:
if result is not None:
assert blocker is None
if blocker is not None:
assert result is None
self.scc_id = scc_id
self.result = result
self.blocker = blocker

@classmethod
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
scc_id = read_int(buf)
tag = read_tag(buf)
if tag == LITERAL_NONE:
return SccResponseMessage(
scc_id=scc_id,
blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)),
)
else:
assert tag == DICT_STR_GEN
return SccResponseMessage(
scc_id=scc_id,
result={
read_str_bare(buf): (read_str(buf), read_str_list(buf))
for _ in range(read_int_bare(buf))
},
)

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, SCC_RESPONSE_MESSAGE)
write_int(buf, self.scc_id)
if self.result is None:
assert self.blocker is not None
write_tag(buf, LITERAL_NONE)
write_str_list(buf, self.blocker.messages)
write_bool(buf, self.blocker.use_stdout)
write_str_opt(buf, self.blocker.module_with_blocker)
else:
write_tag(buf, DICT_STR_GEN)
write_int_bare(buf, len(self.result))
for mod_id in sorted(self.result):
write_str_bare(buf, mod_id)
hex_hash, errs = self.result[mod_id]
write_str(buf, hex_hash)
write_str_list(buf, errs)


class SourcesDataMessage(IPCMessage):
"""A message wrapping a list of build sources."""

def __init__(self, *, sources: list[BuildSource]) -> None:
self.sources = sources

@classmethod
def read(cls, buf: ReadBuffer) -> SourcesDataMessage:
assert read_tag(buf) == SOURCES_DATA_MESSAGE
sources = [
BuildSource(
read_str_opt(buf),
read_str_opt(buf),
read_str_opt(buf),
read_str_opt(buf),
read_bool(buf),
)
for _ in range(read_int_bare(buf))
]
return SourcesDataMessage(sources=sources)

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, SOURCES_DATA_MESSAGE)
write_int_bare(buf, len(self.sources))
for bs in self.sources:
write_str_opt(buf, bs.path)
write_str_opt(buf, bs.module)
write_str_opt(buf, bs.text)
write_str_opt(buf, bs.base_dir)
write_bool(buf, bs.followed)


class SccsDataMessage(IPCMessage):
"""A message wrapping the SCC structure computed by the coordinator."""

def __init__(self, *, sccs: list[SCC]) -> None:
self.sccs = sccs

@classmethod
def read(cls, buf: ReadBuffer) -> SccsDataMessage:
assert read_tag(buf) == SCCS_DATA_MESSAGE
sccs = [
SCC(set(read_str_list(buf)), read_int(buf), read_int_list(buf))
for _ in range(read_int_bare(buf))
]
return SccsDataMessage(sccs=sccs)

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, SCCS_DATA_MESSAGE)
write_int_bare(buf, len(self.sccs))
for scc in self.sccs:
write_str_list(buf, sorted(scc.mod_ids))
write_int(buf, scc.id)
write_int_list(buf, sorted(scc.deps))
Loading