Skip to content

Commit

Permalink
feat: use generator for stream results (#926)
Browse files Browse the repository at this point in the history
* feat: use iterator for query results

* use iterator as suggested in comments

* remove unnecessary class

* more iterators

* undo adding await

* address comments

* undo bundle change

* undo bundle change

* cleanups and docstrings

* fix type hint

* unit tests

* lint

* skip tests with anext for python < 3.10

* lint

* address comments

* lint

* fix type hint

* type hint

* sys test debug

* sys test debug

* undo change for debug

* address comment

* system test debug

* undo system test debug code
  • Loading branch information
Linchin authored Jul 9, 2024
1 parent 104293b commit 3e5df35
Show file tree
Hide file tree
Showing 16 changed files with 703 additions and 175 deletions.
77 changes: 58 additions & 19 deletions google/cloud/firestore_v1/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
BaseAggregationQuery,
_query_response_to_result,
)
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.stream_generator import StreamGenerator

from typing import Generator, Union, List, Any
from typing import Any, Generator, List, Optional, TYPE_CHECKING, Union

# Types needed only for Type Hints
if TYPE_CHECKING:
from google.cloud.firestore_v1 import transaction # pragma: NO COVER


class AggregationQuery(BaseAggregationQuery):
Expand Down Expand Up @@ -99,36 +105,34 @@ def _retry_query_after_exception(self, exc, retry, transaction):

return False

def stream(
def _make_stream(
self,
transaction=None,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> Union[Generator[List[AggregationResult], Any, None]]:
"""Runs the aggregation query.
"""Internal method for stream(). Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns an iterator which
consumes each document returned in the stream of ``RunAggregationQueryResponse``
messages.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
which consumes each document returned in the stream of
``RunAggregationQueryResponse`` messages.
If a ``transaction`` is used and it already has write operations
added, this method cannot be used (i.e. read-after-write is not
allowed).
If a ``transaction`` is used and it already has write operations added,
this method cannot be used (i.e. read-after-write is not allowed).
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
retry (Optional[google.api_core.retry.Retry]): Designation of what
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
Yields:
:class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`:
The result of aggregations of this query
The result of aggregations of this query.
"""

response_iterator = self._get_stream_iterator(
Expand All @@ -154,3 +158,38 @@ def stream(
break
result = _query_response_to_result(response)
yield result

def stream(
self,
transaction: Optional["transaction.Transaction"] = None,
retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "StreamGenerator[DocumentSnapshot]":
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
which consumes each document returned in the stream of
``RunAggregationQueryResponse`` messages.
If a ``transaction`` is used and it already has write operations added,
this method cannot be used (i.e. read-after-write is not allowed).
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
retry (Optional[google.api_core.retry.Retry]): Designation of what
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Optinal[float]): The timeout for this request. Defaults
to a system-specified value.
Returns:
`StreamGenerator[DocumentSnapshot]`: A generator of the query results.
"""
inner_generator = self._make_stream(
transaction=transaction,
retry=retry,
timeout=timeout,
)
return StreamGenerator(inner_generator)
74 changes: 57 additions & 17 deletions google/cloud/firestore_v1/async_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@
from google.api_core import gapic_v1
from google.api_core import retry_async as retries

from typing import List, Union, AsyncGenerator

from typing import AsyncGenerator, List, Optional, Union, TYPE_CHECKING

from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_aggregation import (
AggregationResult,
_query_response_to_result,
BaseAggregationQuery,
)
from google.cloud.firestore_v1 import transaction

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.base_document import DocumentSnapshot


class AsyncAggregationQuery(BaseAggregationQuery):
Expand Down Expand Up @@ -76,17 +80,15 @@ async def get(
result = [aggregation async for aggregation in stream_result]
return result # type: ignore

async def stream(
async def _make_stream(
self,
transaction=None,
retry: Union[
retries.AsyncRetry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> Union[AsyncGenerator[List[AggregationResult], None]]:
"""Runs the aggregation query.
"""Internal method for stream(). Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns an iterator which
This sends a ``RunAggregationQuery`` RPC and then returns a generator which
consumes each document returned in the stream of ``RunAggregationQueryResponse``
messages.
Expand All @@ -95,13 +97,14 @@ async def stream(
allowed).
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
Transaction`]):
An existing transaction that the query will run in.
retry (Optional[google.api_core.retry.Retry]): Designation of what
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
Yields:
:class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`:
Expand All @@ -122,3 +125,40 @@ async def stream(
async for response in response_iterator:
result = _query_response_to_result(response)
yield result

def stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "AsyncStreamGenerator[DocumentSnapshot]":
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
which consumes each document returned in the stream of
``RunAggregationQueryResponse`` messages.
If a ``transaction`` is used and it already has write operations added,
this method cannot be used (i.e. read-after-write is not allowed).
Args:
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
Transaction`]):
An existing transaction that the query will run in.
retry (Optional[google.api_core.retry.Retry]): Designation of what
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
Returns:
`AsyncStreamGenerator[DocumentSnapshot]`:
A generator of the query results.
"""

inner_generator = self._make_stream(
transaction=transaction,
retry=retry,
timeout=timeout,
)
return AsyncStreamGenerator(inner_generator)
65 changes: 35 additions & 30 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@
BaseCollectionReference,
_item_to_document_ref,
)
from google.cloud.firestore_v1 import async_query, async_document, async_aggregation
from google.cloud.firestore_v1 import (
async_aggregation,
async_document,
async_query,
transaction,
)

from google.cloud.firestore_v1.document import DocumentReference

from typing import AsyncIterator
from typing import Any, AsyncGenerator, Tuple
from typing import Any, AsyncGenerator, Optional, Tuple, TYPE_CHECKING

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot


class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
Expand Down Expand Up @@ -176,9 +181,9 @@ async def list_documents(

async def get(
self,
transaction: Transaction = None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> list:
"""Read the documents in this collection.
Expand All @@ -189,14 +194,14 @@ async def get(
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
retry (Optional[google.api_core.retry.Retry]): Designation of what
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Otional[float]): The timeout for this request. Defaults
to a system-specified value.
If a ``transaction`` is used and it already has write operations
added, this method cannot be used (i.e. read-after-write is not
allowed).
If a ``transaction`` is used and it already has write operations added,
this method cannot be used (i.e. read-after-write is not allowed).
Returns:
list: The documents in this collection that match the query.
Expand All @@ -205,15 +210,15 @@ async def get(

return await query.get(transaction=transaction, **kwargs)

async def stream(
def stream(
self,
transaction: Transaction = None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> AsyncIterator[async_document.DocumentSnapshot]:
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "AsyncStreamGenerator[DocumentSnapshot]":
"""Read the documents in this collection.
This sends a ``RunQuery`` RPC and then returns an iterator which
This sends a ``RunQuery`` RPC and then returns a generator which
consumes each document returned in the stream of ``RunQueryResponse``
messages.
Expand All @@ -232,16 +237,16 @@ async def stream(
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
Transaction`]):
An existing transaction that the query will run in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
retry (Optional[google.api_core.retry.Retry]): Designation of what
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
Yields:
:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
The next document that fulfills the query.
Returns:
`AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query
results.
"""
query, kwargs = self._prep_get_or_stream(retry, timeout)

async for d in query.stream(transaction=transaction, **kwargs):
yield d # pytype: disable=name-error
return query.stream(transaction=transaction, **kwargs)
Loading

0 comments on commit 3e5df35

Please sign in to comment.