Skip to content

Commit

Permalink
Type aiocoap api (#444)
Browse files Browse the repository at this point in the history
* Type aiocoap api

* Revert name change

* Fix issues
  • Loading branch information
MartinHjelmare authored Mar 5, 2022
1 parent a563561 commit 8295e49
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 48 deletions.
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true

[mypy-pytradfri.api.aiocoap_api]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true

[mypy-pytradfri.device.controller]
check_untyped_defs = true
disallow_incomplete_defs = true
Expand Down
98 changes: 56 additions & 42 deletions pytradfri/api/aiocoap_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from enum import Enum
import json
import logging
from typing import Callable
from typing import Any, Union, cast

from aiocoap import Context, Message
from aiocoap.credentials import CredentialsMissingError
Expand All @@ -15,56 +17,73 @@
TimeoutError,
)
from aiocoap.numbers.codes import Code
from aiocoap.protocol import BlockwiseRequest

from ..command import Command
from ..error import ClientError, RequestTimeout, ServerError
from ..gateway import Gateway

_LOGGER = logging.getLogger(__name__)
_SENTINEL = object()


class UndefinedType(Enum):
"""Singleton type for use with not set sentinel values."""

_singleton = 0


_SENTINEL = UndefinedType._singleton # pylint: disable=protected-access


class APIFactory:
"""ApiFactory."""

def __init__(self, host: str, psk_id="pytradfri", psk=None, internal_create=None):
def __init__(
self,
host: str,
psk_id: str = "pytradfri",
psk: str | None = None,
internal_create: UndefinedType | None = None,
):
"""Create object of class."""
if internal_create is not _SENTINEL:
raise ValueError("Use APIFactory.init(…) to initialize APIFactory")

self._psk = psk
self._host = host
self._psk_id = psk_id
self._observations_err_callbacks: list[Callable] = []
self._protocol = None
self._observations_err_callbacks: list[Callable[[Exception], None]] = []
self._protocol: asyncio.Task | None = None
self._reset_lock = asyncio.Lock()
self._shutdown = False

@classmethod
async def init(cls, host, psk_id="pytradfri", psk=None) -> "APIFactory":
async def init(
cls, host: str, psk_id: str = "pytradfri", psk: str | None = None
) -> "APIFactory":
"""Initialize an APIFactory."""
instance = cls(host, psk_id=psk_id, psk=psk, internal_create=_SENTINEL)
if psk:
await instance._update_credentials()
return instance

@property
def psk_id(self):
def psk_id(self) -> str:
"""Return psk id."""
return self._psk_id

@property
def psk(self):
def psk(self) -> str | None:
"""Return psk."""
return self._psk

async def _get_protocol(self):
async def _get_protocol(self) -> Context:
"""Get the protocol for the request."""
if self._protocol is None:
self._protocol = asyncio.create_task(Context.create_client_context())
return await self._protocol

async def _reset_protocol(self, exc=None):
async def _reset_protocol(self, exc: BaseException | None = None) -> None:
"""Reset the protocol if an error occurs."""
skip = self._reset_lock.locked()
async with self._reset_lock:
Expand All @@ -91,20 +110,22 @@ async def _reset_protocol(self, exc=None):
# Clear the saved callbacks
self._observations_err_callbacks.clear()

async def shutdown(self, exc=None):
async def shutdown(self, exc: Exception | None = None) -> None:
"""Shutdown the API events.
This should be called before closing the event loop.
"""
await self._reset_protocol(exc)
self._shutdown = True

async def _get_response(self, msg, timeout):
async def _get_response(
self, msg: Message, timeout: float | None
) -> tuple[BlockwiseRequest, Message]:
"""Perform the request, get the response."""
try:
protocol = await self._get_protocol()
pr_req = protocol.request(msg)
pr_resp = await asyncio.wait_for(pr_req.response, timeout)
pr_req: BlockwiseRequest = protocol.request(msg)
pr_resp: Message = await asyncio.wait_for(pr_req.response, timeout)
return pr_req, pr_resp
except CredentialsMissingError as exc:
await self._reset_protocol(exc)
Expand All @@ -131,7 +152,7 @@ async def _get_response(self, msg, timeout):
await self._update_credentials()
raise exc

async def _execute(self, api_command, timeout):
async def _execute(self, api_command: Command, timeout: float | None) -> Any:
"""Execute the command."""
if api_command.observe:
await self._observe(api_command, timeout)
Expand Down Expand Up @@ -177,33 +198,22 @@ async def _execute(self, api_command, timeout):

return api_command.result

def debug_comm(self, call_type, api_commands):
"""Log request/return."""
if not isinstance(api_commands, list):
api_msg = "single: "
api_commands = [api_commands]
else:
api_msg = "multiple: "
for api_command in api_commands:
if hasattr(api_command, "__dict__"):
api_msg += f"<<<{vars(api_command)}>>>"
else:
api_msg += f"+++{api_commands}+++"
msg = f"REQUEST {call_type}: {self._host} {api_msg}"
_LOGGER.debug(msg)

async def request(self, api_commands, timeout=None):
async def request(
self, api_commands: Command | list[Command], timeout: float | None = None
) -> Any:
"""Make a request."""
self.debug_comm("call", api_commands)
if not isinstance(api_commands, list):
_LOGGER.debug("REQUEST call single: %s %s", self._host, api_commands)
result = await self._execute(api_commands, timeout)
self.debug_comm("return", result)
_LOGGER.debug("REQUEST result single: %s", result)

return result

_LOGGER.debug("REQUEST call multiple: %s %s", self._host, api_commands)
commands = (self._execute(api_command, timeout) for api_command in api_commands)
command_results = await asyncio.gather(*commands)
command_results: list = await asyncio.gather(*commands)
_LOGGER.debug("REQUEST result multiple: %s", command_results)

self.debug_comm("return", command_results)
return command_results

async def _observe(self, api_command: Command, timeout: float | None) -> None:
Expand All @@ -219,22 +229,26 @@ async def _observe(self, api_command: Command, timeout: float | None) -> None:

api_command.result = _process_output(pr_rsp)

def success_callback(res):
def success_callback(res: Message) -> None:
api_command.result = _process_output(res)

def error_callback(exc: Exception) -> None:
if isinstance(exc, LibraryShutdown):
_LOGGER.debug("Protocol is shutdown, stopping observation")
return

if err_callback:
err_callback(exc)

observation = pr_req.observation
# The observation is set on the request
# since we pass a Message with observe set above.
assert observation is not None
observation.register_callback(success_callback)
observation.register_errback(error_callback)
self._observations_err_callbacks.append(observation.error)

async def generate_psk(self, security_key):
async def generate_psk(self, security_key: str) -> str:
"""Generate and set a psk from the security key."""
if not self._psk:
# Set context once for generating key
Expand All @@ -251,7 +265,7 @@ async def generate_psk(self, security_key):
}
)

self._psk = await self.request(command)
self._psk = cast(str, await self.request(command))

# aiocoap has now cached our psk, so it must be reset.
# We also no longer need the protocol, so this will clean that up.
Expand All @@ -260,7 +274,7 @@ async def generate_psk(self, security_key):

return self._psk

async def _update_credentials(self):
async def _update_credentials(self) -> None:
"""Update credentials."""
if not self._psk:
# No credentials to reset
Expand All @@ -278,9 +292,9 @@ async def _update_credentials(self):
)


def _process_output(res, parse_json=True):
def _process_output(res: Message, parse_json: bool = True) -> list | dict | str | None:
"""Process output."""
res_payload = res.payload.decode("utf-8")
res_payload: str = res.payload.decode("utf-8")
output = res_payload.strip()

_LOGGER.debug("Status: %s, Received: %s", res.code, output)
Expand All @@ -300,4 +314,4 @@ def _process_output(res, parse_json=True):
if not parse_json:
return output

return json.loads(output)
return cast(Union[list, dict], json.loads(output))
12 changes: 6 additions & 6 deletions pytradfri/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import deepcopy
from typing import Any, Callable, Optional

TypeProcessResultCb = Optional[Callable[[Any], Optional[Any]]]
TypeProcessResultCb = Optional[Callable[[Any], Any]]


class Command:
Expand All @@ -31,8 +31,8 @@ def __init__(
self._err_callback = err_callback
self._observe = observe
self._observe_duration = observe_duration
self._raw_result: str | None = None
self._result: str | None = None
self._raw_result: list | dict | str | None = None
self._result: Any = None

@property
def method(self) -> str:
Expand Down Expand Up @@ -75,17 +75,17 @@ def observe_duration(self) -> int:
return self._observe_duration

@property
def raw_result(self) -> str | None:
def raw_result(self) -> list | dict | str | None:
"""Return raw result."""
return self._raw_result

@property
def result(self) -> str | None:
def result(self) -> Any:
"""Return result."""
return self._result

@result.setter
def result(self, value: str) -> None:
def result(self, value: list | dict | str | None) -> None:
"""Return command result."""
if self._process_result:
self._result = self._process_result(value)
Expand Down

0 comments on commit 8295e49

Please sign in to comment.