Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse.events. #10998

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10998.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.events`.
5 changes: 1 addition & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ files =
synapse/config,
synapse/crypto,
synapse/event_auth.py,
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/validator.py,
synapse/events,
synapse/federation,
synapse/groups,
synapse/handlers,
Expand Down
121 changes: 61 additions & 60 deletions synapse/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import abc
import os
from typing import Dict, Optional, Tuple, Type
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

from unpaddedbase64 import encode_base64

Expand Down Expand Up @@ -86,7 +86,7 @@ class DefaultDictProperty(DictProperty):

__slots__ = ["default"]

def __init__(self, key, default):
def __init__(self, key: str, default: Any):
super().__init__(key)
self.default = default

Expand All @@ -111,22 +111,25 @@ def __init__(self, internal_metadata_dict: JsonDict):
# in the DAG)
self.outlier = False

out_of_band_membership: bool = DictProperty("out_of_band_membership")
send_on_behalf_of: str = DictProperty("send_on_behalf_of")
recheck_redaction: bool = DictProperty("recheck_redaction")
soft_failed: bool = DictProperty("soft_failed")
proactively_send: bool = DictProperty("proactively_send")
redacted: bool = DictProperty("redacted")
txn_id: str = DictProperty("txn_id")
token_id: int = DictProperty("token_id")
historical: bool = DictProperty("historical")
# Tell mypy to ignore the types of properties defined by DictProperty until
# that is properly annotated.

out_of_band_membership: bool = DictProperty("out_of_band_membership") # type: ignore[assignment]
send_on_behalf_of: str = DictProperty("send_on_behalf_of") # type: ignore[assignment]
recheck_redaction: bool = DictProperty("recheck_redaction") # type: ignore[assignment]
soft_failed: bool = DictProperty("soft_failed") # type: ignore[assignment]
proactively_send: bool = DictProperty("proactively_send") # type: ignore[assignment]
redacted: bool = DictProperty("redacted") # type: ignore[assignment]
txn_id: str = DictProperty("txn_id") # type: ignore[assignment]
token_id: int = DictProperty("token_id") # type: ignore[assignment]
historical: bool = DictProperty("historical") # type: ignore[assignment]

# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
before: RoomStreamToken = DictProperty("before")
after: RoomStreamToken = DictProperty("after")
order: Tuple[int, int] = DictProperty("order")
before: RoomStreamToken = DictProperty("before") # type: ignore[assignment]
after: RoomStreamToken = DictProperty("after") # type: ignore[assignment]
order: Tuple[int, int] = DictProperty("order") # type: ignore[assignment]

def get_dict(self) -> JsonDict:
return dict(self._dict)
Expand Down Expand Up @@ -162,9 +165,6 @@ def need_to_check_redaction(self) -> bool:

If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.

Returns:
bool
"""
return self._dict.get("recheck_redaction", False)

Expand All @@ -176,32 +176,23 @@ def is_soft_failed(self) -> bool:
sent to clients.
2. They should not be added to the forward extremities (and
therefore not to current state).

Returns:
bool
"""
return self._dict.get("soft_failed", False)

def should_proactively_send(self):
def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and
servers.

This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.

Returns:
bool
"""
return self._dict.get("proactively_send", True)

def is_redacted(self):
def is_redacted(self) -> bool:
"""Whether the event has been redacted.

This is used for efficiently checking whether an event has been
marked as redacted without needing to make another database call.

Returns:
bool
"""
return self._dict.get("redacted", False)

Expand Down Expand Up @@ -241,29 +232,37 @@ def __init__(

self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)

auth_events = DictProperty("auth_events")
depth = DictProperty("depth")
content = DictProperty("content")
hashes = DictProperty("hashes")
origin = DictProperty("origin")
origin_server_ts = DictProperty("origin_server_ts")
prev_events = DictProperty("prev_events")
redacts = DefaultDictProperty("redacts", None)
room_id = DictProperty("room_id")
sender = DictProperty("sender")
state_key = DictProperty("state_key")
type = DictProperty("type")
user_id = DictProperty("sender")
# Tell mypy to ignore the types of properties defined by DictProperty until
# that is properly annotated.
#
# Note that auth_events, prev_events differ based on the sub-classes of
# EventBase.
#
# TODO: Add a type for state_key.

auth_events = DictProperty("auth_events") # type: ignore[assignment]
depth: int = DictProperty("depth") # type: ignore[assignment]
content: JsonDict = DictProperty("content") # type: ignore[assignment]
hashes: Dict[str, str] = DictProperty("hashes") # type: ignore[assignment]
origin: str = DictProperty("origin") # type: ignore[assignment]
origin_server_ts: int = DictProperty("origin_server_ts") # type: ignore[assignment]
prev_events = DictProperty("prev_events") # type: ignore[assignment]
redacts: Optional[str] = DefaultDictProperty("redacts", None) # type: ignore[assignment]
room_id: str = DictProperty("room_id") # type: ignore[assignment]
sender: str = DictProperty("sender") # type: ignore[assignment]
state_key = DictProperty("state_key") # type: ignore[assignment]
type: str = DictProperty("type") # type: ignore[assignment]
user_id: str = DictProperty("sender") # type: ignore[assignment]

@property
def event_id(self) -> str:
raise NotImplementedError()

@property
def membership(self):
def membership(self) -> str:
return self.content["membership"]

def is_state(self):
def is_state(self) -> bool:
return hasattr(self, "state_key") and self.state_key is not None

def get_dict(self) -> JsonDict:
Expand All @@ -272,10 +271,10 @@ def get_dict(self) -> JsonDict:

return d

def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Any:
return self._dict.get(key, default)

def get_internal_metadata_dict(self):
def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict()

def get_pdu_json(self, time_now=None) -> JsonDict:
Expand Down Expand Up @@ -305,49 +304,49 @@ def get_templated_pdu_json(self) -> JsonDict:

return template_json

def __set__(self, instance, value):
def __set__(self, instance, value) -> None:
raise AttributeError("Unrecognized attribute %s" % (instance,))

def __getitem__(self, field):
def __getitem__(self, field: str) -> Optional[Any]:
return self._dict[field]

def __contains__(self, field):
def __contains__(self, field: str) -> bool:
return field in self._dict

def items(self):
def items(self) -> List[Tuple[str, Optional[Any]]]:
return list(self._dict.items())

def keys(self):
def keys(self) -> Iterable[str]:
return self._dict.keys()

def prev_event_ids(self):
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.

Returns:
list[str]: The list of event IDs of this event's prev_events
The list of event IDs of this event's prev_events
"""
return [e for e, _ in self.prev_events]

def auth_event_ids(self):
def auth_event_ids(self) -> List[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.

Returns:
list[str]: The list of event IDs of this event's auth_events
The list of event IDs of this event's auth_events
"""
return [e for e, _ in self.auth_events]

def freeze(self):
def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident"""

# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)

def __str__(self):
def __str__(self) -> str:
return self.__repr__()

def __repr__(self):
def __repr__(self) -> str:
return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
self.__class__.__name__,
self.event_id,
Expand Down Expand Up @@ -439,7 +438,7 @@ def __init__(
else:
frozen_dict = event_dict

self._event_id = None
self._event_id: Optional[str] = None

super().__init__(
frozen_dict,
Expand Down Expand Up @@ -499,12 +498,14 @@ def event_id(self):
return self._event_id


def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
def _event_type_from_format_version(
format_version: int,
) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
"""Returns the python type to use to construct an Event object for the
given event format version.

Args:
format_version (int): The event format version
format_version: The event format version

Returns:
type: A type that can be initialized as per the initializer of
Expand Down
4 changes: 2 additions & 2 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ class EventBuilder:
)

@property
def state_key(self):
def state_key(self) -> str:
if self._state_key is not None:
return self._state_key

raise AttributeError("state_key")

def is_state(self):
def is_state(self) -> bool:
return self._state_key is not None

async def build(
Expand Down
16 changes: 8 additions & 8 deletions synapse/events/presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Expand All @@ -33,14 +34,13 @@
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
GET_INTERESTED_USERS_CALLBACK = Callable[
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]

logger = logging.getLogger(__name__)


def load_legacy_presence_router(hs: "HomeServer"):
def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
Expand Down Expand Up @@ -69,7 +69,7 @@ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
if f is None:
return None

def run(*args, **kwargs):
def run(*args: Any, **kwargs: Any) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
Expand Down Expand Up @@ -104,7 +104,7 @@ def register_presence_router_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
):
) -> None:
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
Expand Down Expand Up @@ -142,7 +142,7 @@ async def get_users_for_states(
# Don't include any extra destinations for presence updates
return {}

users_for_states = {}
users_for_states: Dict[str, Set[UserPresenceState]] = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
Expand Down Expand Up @@ -171,7 +171,7 @@ async def get_users_for_states(

return users_for_states

async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
Expand Down
Loading