Skip to content

Commit

Permalink
feat: query profiling part 2: asynchronous (#961)
Browse files Browse the repository at this point in the history
* feat: support query profiling

* collection

* fix unit tests

* unit tests

* vector get and stream, unit tests

* aggregation get and stream, unit tests

* docstring

* query profile unit tests

* update base classes' method signature

* documentsnapshotlist unit tests

* func signatures

* undo client.py change

* transaction.get()

* lint

* system test

* fix shim test

* fix sys test

* fix sys test

* system test

* another system test

* skip system test in emulator

* stream generator unit tests

* coverage

* add system tests

* small fixes

* undo document change

* add system tests

* vector query system tests

* format

* fix system test

* comments

* add system tests

* improve stream generator

* type checking

* adding stars

* delete comment

* remove coverage requirements for type checking part

* add explain_options to StreamGenerator

* yield tuple instead

* raise exception when explain_metrics is absent

* refactor documentsnapshotlist into queryresultslist

* add comment

* improve type hint

* lint

* move QueryResultsList to stream_generator.py

* aggregation related type annotation

* transaction return type hint

* refactor QueryResultsList

* change stream generator to return ExplainMetrics instead of yield

* update aggregation query to use the new generator

* update query to use the new generator

* update vector query to use the new generator

* lint

* type annotations

* fix type annotation to be python 3.9 compatible

* fix type hint for python 3.8

* fix system test

* add test coverage

* use class method get_explain_metrics() instead of property explain_metrics

* feat: add explain_metrics to async generator

* async support for query

* system tests for query

* query profile for async vector query

* vector query system test

* async transaction

* async transaction system test

* async collection

* fix system test

* test coverage

* test coverage

* collection system test

* async aggregation

* lint

* cover

* lint

* aggregation system tests

* cover and fix system test

* delete type ignore

* improve type annotation

* mypy

* mypy

* address comments

* delete comments

* address comments
  • Loading branch information
Linchin authored Sep 20, 2024
1 parent 1d2a494 commit 060a3ef
Show file tree
Hide file tree
Showing 25 changed files with 2,189 additions and 276 deletions.
68 changes: 54 additions & 14 deletions google/cloud/firestore_v1/async_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,23 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union
from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Union

from google.api_core import gapic_v1
from google.api_core import retry_async as retries

from google.cloud.firestore_v1 import transaction
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_aggregation import (
AggregationResult,
BaseAggregationQuery,
_query_response_to_result,
)
from google.cloud.firestore_v1.query_results import QueryResultsList

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_aggregation import AggregationResult
from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions
import google.cloud.firestore_v1.types.query_profile as query_profile_pb


class AsyncAggregationQuery(BaseAggregationQuery):
Expand All @@ -53,7 +55,9 @@ async def get(
retries.AsyncRetry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> List[List[AggregationResult]]:
*,
explain_options: Optional[ExplainOptions] = None,
) -> QueryResultsList[List[AggregationResult]]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages.
Expand All @@ -69,23 +73,39 @@ async def get(
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
List[List[AggregationResult]]: The aggregation query results
QueryResultsList[List[AggregationResult]]: The aggregation query results.
"""
explain_metrics: ExplainMetrics | None = None

stream_result = self.stream(
transaction=transaction, retry=retry, timeout=timeout
transaction=transaction,
retry=retry,
timeout=timeout,
explain_options=explain_options,
)
result = [aggregation async for aggregation in stream_result]
return result # type: ignore

if explain_options is None:
explain_metrics = None
else:
explain_metrics = await stream_result.get_explain_metrics()

return QueryResultsList(result, explain_options, explain_metrics)

async def _make_stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> Union[AsyncGenerator[List[AggregationResult], None]]:
explain_options: Optional[ExplainOptions] = None,
) -> AsyncGenerator[List[AggregationResult] | query_profile_pb.ExplainMetrics, Any]:
"""Internal method for stream(). Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator which
Expand All @@ -105,15 +125,23 @@ async def _make_stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Yields:
:class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`:
The result of aggregations of this query
List[AggregationResult] | query_profile_pb.ExplainMetrics:
The result of aggregations of this query. Query results will be
yielded as `List[AggregationResult]`. When the result contains
returned explain metrics, yield `query_profile_pb.ExplainMetrics`
individually.
"""
request, kwargs = self._prep_stream(
transaction,
retry,
timeout,
explain_options,
)

response_iterator = await self._client._firestore_api.run_aggregation_query(
Expand All @@ -124,14 +152,21 @@ async def _make_stream(

async for response in response_iterator:
result = _query_response_to_result(response)
yield result
if result:
yield result

if response.explain_metrics:
metrics = response.explain_metrics
yield metrics

def stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "AsyncStreamGenerator[DocumentSnapshot]":
*,
explain_options: Optional[ExplainOptions] = None,
) -> AsyncStreamGenerator[List[AggregationResult]]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
Expand All @@ -150,15 +185,20 @@ def stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
`AsyncStreamGenerator[DocumentSnapshot]`:
`AsyncStreamGenerator[List[AggregationResult]]`:
A generator of the query results.
"""

inner_generator = self._make_stream(
transaction=transaction,
retry=retry,
timeout=timeout,
explain_options=explain_options,
)
return AsyncStreamGenerator(inner_generator)
return AsyncStreamGenerator(inner_generator, explain_options)
26 changes: 23 additions & 3 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Classes for representing collections for the Google Cloud Firestore API."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple

Expand All @@ -35,6 +36,8 @@
if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.query_results import QueryResultsList


class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
Expand Down Expand Up @@ -192,7 +195,9 @@ async def get(
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> list:
*,
explain_options: Optional[ExplainOptions] = None,
) -> QueryResultsList[DocumentSnapshot]:
"""Read the documents in this collection.
This sends a ``RunQuery`` RPC and returns a list of documents
Expand All @@ -207,14 +212,21 @@ async def get(
system-specified policy.
timeout (Otional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
If a ``transaction`` is used and it already has write operations added,
this method cannot be used (i.e. read-after-write is not allowed).
Returns:
list: The documents in this collection that match the query.
QueryResultsList[DocumentSnapshot]:
The documents in this collection that match the query.
"""
query, kwargs = self._prep_get_or_stream(retry, timeout)
if explain_options is not None:
kwargs["explain_options"] = explain_options

return await query.get(transaction=transaction, **kwargs)

Expand All @@ -223,7 +235,9 @@ def stream(
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "AsyncStreamGenerator[DocumentSnapshot]":
*,
explain_options: Optional[ExplainOptions] = None,
) -> AsyncStreamGenerator[DocumentSnapshot]:
"""Read the documents in this collection.
This sends a ``RunQuery`` RPC and then returns a generator which
Expand All @@ -250,11 +264,17 @@ def stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
`AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query
results.
"""
query, kwargs = self._prep_get_or_stream(retry, timeout)
if explain_options:
kwargs["explain_options"] = explain_options

return query.stream(transaction=transaction, **kwargs)
Loading

0 comments on commit 060a3ef

Please sign in to comment.