diff --git a/docs/aggregation.rst b/docs/aggregation.rst new file mode 100644 index 0000000000..ab9bf45467 --- /dev/null +++ b/docs/aggregation.rst @@ -0,0 +1,14 @@ +Aggregation +~~~~~~~~~~~ + +.. automodule:: google.cloud.firestore_v1.aggregation + :members: + :show-inheritance: + +.. automodule:: google.cloud.firestore_v1.base_aggregation + :members: + :show-inheritance: + +.. automodule:: google.cloud.firestore_v1.async_aggregation + :members: + :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index 3fce768ab7..8cf2a17e84 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ API Reference client collection + aggregation document field_path query diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py new file mode 100644 index 0000000000..609f82f75a --- /dev/null +++ b/google/cloud/firestore_v1/aggregation.py @@ -0,0 +1,156 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing aggregation queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from +a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be +a more common way to create an aggregation query than direct usage of the constructor. +""" +from __future__ import annotations + +from google.api_core import exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries + + +from google.cloud.firestore_v1.base_aggregation import ( + AggregationResult, + BaseAggregationQuery, + _query_response_to_result, +) + +from typing import Generator, Union, List, Any + + +class AggregationQuery(BaseAggregationQuery): + """Represents an aggregation query to the Firestore API.""" + + def __init__( + self, + nested_query, + ) -> None: + super(AggregationQuery, self).__init__(nested_query) + + def get( + self, + transaction=None, + retry: Union[ + retries.Retry, None, gapic_v1.method._MethodDefault + ] = gapic_v1.method.DEFAULT, + timeout: float | None = None, + ) -> List[AggregationResult]: + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The aggregation query results + + """ + result = self.stream(transaction=transaction, retry=retry, timeout=timeout) + return list(result) # type: ignore + + def _get_stream_iterator(self, transaction, retry, timeout): + """Helper method for :meth:`stream`.""" + request, kwargs = self._prep_stream( + transaction, + retry, + timeout, + ) + + return self._client._firestore_api.run_aggregation_query( + request=request, + metadata=self._client._rpc_metadata, + **kwargs, + ) + + def _retry_query_after_exception(self, exc, retry, transaction): + """Helper method for :meth:`stream`.""" + if transaction is None: # no snapshot-based retry inside transaction + if retry is gapic_v1.method.DEFAULT: + transport = self._client._firestore_api._transport + gapic_callable = transport.run_aggregation_query + retry = gapic_callable._retry + return retry._predicate(exc) + + return False + + def stream( + self, + transaction=None, + retry: Union[ + retries.Retry, None, gapic_v1.method._MethodDefault + ] = gapic_v1.method.DEFAULT, + timeout: float | None = None, + ) -> Union[Generator[List[AggregationResult], Any, None]]: + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunAggregationQueryResponse`` + messages. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Yields: + :class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`: + The result of aggregations of this query + """ + + response_iterator = self._get_stream_iterator( + transaction, + retry, + timeout, + ) + while True: + try: + response = next(response_iterator, None) + except exceptions.GoogleAPICallError as exc: + if self._retry_query_after_exception(exc, retry, transaction): + response_iterator = self._get_stream_iterator( + transaction, + retry, + timeout, + ) + continue + else: + raise + + if response is None: # EOI + break + result = _query_response_to_result(response) + yield result diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py new file mode 100644 index 0000000000..194016cd23 --- /dev/null +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -0,0 +1,124 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing Async aggregation queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery` can be created directly from +a :class:`~google.cloud.firestore_v1.async_collection.AsyncCollection` and that can be +a more common way to create an aggregation query than direct usage of the constructor. +""" +from __future__ import annotations + +from google.api_core import gapic_v1 +from google.api_core import retry as retries + +from typing import List, Union, AsyncGenerator + + +from google.cloud.firestore_v1.base_aggregation import ( + AggregationResult, + _query_response_to_result, + BaseAggregationQuery, +) + + +class AsyncAggregationQuery(BaseAggregationQuery): + """Represents an aggregation query to the Firestore API.""" + + def __init__( + self, + nested_query, + ) -> None: + super(AsyncAggregationQuery, self).__init__(nested_query) + + async def get( + self, + transaction=None, + retry: Union[ + retries.Retry, None, gapic_v1.method._MethodDefault + ] = gapic_v1.method.DEFAULT, + timeout: float | None = None, + ) -> List[AggregationResult]: + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The aggregation query results + + """ + stream_result = self.stream( + transaction=transaction, retry=retry, timeout=timeout + ) + result = [aggregation async for aggregation in stream_result] + return result # type: ignore + + async def stream( + self, + transaction=None, + retry: Union[ + retries.Retry, None, gapic_v1.method._MethodDefault + ] = gapic_v1.method.DEFAULT, + timeout: float | None = None, + ) -> Union[AsyncGenerator[List[AggregationResult], None]]: + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunAggregationQueryResponse`` + messages. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Yields: + :class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`: + The result of aggregations of this query + """ + request, kwargs = self._prep_stream( + transaction, + retry, + timeout, + ) + + response_iterator = await self._client._firestore_api.run_aggregation_query( + request=request, + metadata=self._client._rpc_metadata, + **kwargs, + ) + + async for response in response_iterator: + result = _query_response_to_result(response) + yield result diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 52847a3dcf..e997455092 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -21,10 +21,7 @@ BaseCollectionReference, _item_to_document_ref, ) -from google.cloud.firestore_v1 import ( - async_query, - async_document, -) +from google.cloud.firestore_v1 import async_query, async_document, async_aggregation from google.cloud.firestore_v1.document import DocumentReference @@ -72,6 +69,14 @@ def _query(self) -> async_query.AsyncQuery: """ return async_query.AsyncQuery(self) + def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery: + """AsyncAggregationQuery factory. + + Returns: + :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery + """ + return async_aggregation.AsyncAggregationQuery(self._query()) + async def _chunkify(self, chunk_size: int): async for page in self._query()._chunkify(chunk_size): yield page diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 1ad0459f74..efa172520a 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -18,6 +18,7 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ +from __future__ import annotations from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -39,6 +40,8 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery + class AsyncQuery(BaseQuery): """Represents a query to the Firestore API. @@ -213,6 +216,21 @@ async def get( return result + def count( + self, alias: str | None = None + ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: + """Adds a count over the nested query. + + Args: + alias + (Optional[str]): The alias for the count + + Returns: + :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`: + An instance of an AsyncAggregationQuery object + """ + return AsyncAggregationQuery(self).count(alias=alias) + async def stream( self, transaction=None, diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py new file mode 100644 index 0000000000..b7a6605b87 --- /dev/null +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -0,0 +1,221 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing aggregation queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from +a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be +a more common way to create an aggregation query than direct usage of the constructor. +""" + + +from __future__ import annotations + +import abc + + +from abc import ABC + +from typing import List, Coroutine, Union, Tuple, Generator, Any, AsyncGenerator + +from google.api_core import gapic_v1 +from google.api_core import retry as retries + + +from google.cloud.firestore_v1.types import RunAggregationQueryResponse + +from google.cloud.firestore_v1.types import StructuredAggregationQuery +from google.cloud.firestore_v1 import _helpers + + +class AggregationResult(object): + """ + A class representing result from Aggregation Query + :type alias: str + :param alias: The alias for the aggregation. + :type value: int + :param value: The resulting value from the aggregation. + :type read_time: + :param value: The resulting read_time + """ + + def __init__(self, alias: str, value: int, read_time=None): + self.alias = alias + self.value = value + self.read_time = read_time + + def __repr__(self): + return f"" + + +class BaseAggregation(ABC): + @abc.abstractmethod + def _to_protobuf(self): + """Convert this instance to the protobuf representation""" + + +class CountAggregation(BaseAggregation): + def __init__(self, alias: str | None = None): + self.alias = alias + + def _to_protobuf(self): + """Convert this instance to the protobuf representation""" + aggregation_pb = StructuredAggregationQuery.Aggregation() + aggregation_pb.alias = self.alias + aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() + return aggregation_pb + + +def _query_response_to_result( + response_pb: RunAggregationQueryResponse, +) -> List[AggregationResult]: + results = [ + AggregationResult( + alias=key, + value=response_pb.result.aggregate_fields[key].integer_value, + read_time=response_pb.read_time, + ) + for key in response_pb.result.aggregate_fields.pb.keys() + ] + + return results + + +class BaseAggregationQuery(ABC): + """Represents an aggregation query to the Firestore API.""" + + def __init__( + self, + nested_query, + ) -> None: + self._nested_query = nested_query + self._collection_ref = nested_query._parent + self._aggregations: List[BaseAggregation] = [] + + @property + def _client(self): + return self._collection_ref._client + + def count(self, alias: str | None = None): + """ + Adds a count over the nested query + """ + count_aggregation = CountAggregation(alias=alias) + self._aggregations.append(count_aggregation) + return self + + def add_aggregation(self, aggregation: BaseAggregation) -> None: + """ + Adds an aggregation operation to the nested query + + :type aggregation: :class:`google.cloud.firestore_v1.aggregation.BaseAggregation` + :param aggregation: An aggregation operation, e.g. a CountAggregation + """ + self._aggregations.append(aggregation) + + def add_aggregations(self, aggregations: List[BaseAggregation]) -> None: + """ + Adds a list of aggregations to the nested query + + :type aggregations: list + :param aggregations: a list of aggregation operations + """ + self._aggregations.extend(aggregations) + + def _to_protobuf(self) -> StructuredAggregationQuery: + pb = StructuredAggregationQuery() + pb.structured_query = self._nested_query._to_protobuf() + + for aggregation in self._aggregations: + aggregation_pb = aggregation._to_protobuf() + pb.aggregations.append(aggregation_pb) + return pb + + def _prep_stream( + self, + transaction=None, + retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None, + timeout: float | None = None, + ) -> Tuple[dict, dict]: + parent_path, expected_prefix = self._collection_ref._parent_info() + request = { + "parent": parent_path, + "structured_aggregation_query": self._to_protobuf(), + "transaction": _helpers.get_transaction_id(transaction), + } + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + return request, kwargs + + @abc.abstractmethod + def get( + self, + transaction=None, + retry: Union[ + retries.Retry, None, gapic_v1.method._MethodDefault + ] = gapic_v1.method.DEFAULT, + timeout: float | None = None, + ) -> List[AggregationResult] | Coroutine[Any, Any, List[AggregationResult]]: + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The aggregation query results + + """ + + @abc.abstractmethod + def stream( + self, + transaction=None, + retry: Union[ + retries.Retry, None, gapic_v1.method._MethodDefault + ] = gapic_v1.method.DEFAULT, + timeout: float | None = None, + ) -> Generator[List[AggregationResult], Any, None] | AsyncGenerator[ + List[AggregationResult], None + ]: + """Runs the aggregation query. + + This sends a``RunAggregationQuery`` RPC and returns an iterator in the stream of ``RunAggregationQueryResponse`` messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The aggregation query results + + """ diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index e9d9867f8d..b8781d236e 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -19,6 +19,9 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + from typing import ( Any, AsyncGenerator, @@ -107,6 +110,9 @@ def parent(self): def _query(self) -> BaseQuery: raise NotImplementedError + def _aggregation_query(self) -> BaseAggregationQuery: + raise NotImplementedError + def document(self, document_id: str = None) -> DocumentReference: """Create a sub-document underneath the current collection. @@ -474,6 +480,15 @@ def stream( def on_snapshot(self, callback) -> NoReturn: raise NotImplementedError + def count(self, alias=None): + """ + Adds a count over the nested query. + + :type alias: str + :param alias: (Optional) The alias for the count + """ + return self._aggregation_query().count(alias=alias) + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 9ac7735afd..1d430a1e91 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -18,12 +18,15 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ +from __future__ import annotations + import copy import math from google.api_core import retry as retries from google.protobuf import wrappers_pb2 +from google.cloud import firestore_v1 from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import document from google.cloud.firestore_v1 import field_path as field_path_module @@ -806,6 +809,11 @@ def _to_protobuf(self) -> StructuredQuery: query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) return query.StructuredQuery(**query_kwargs) + def count( + self, alias: str | None = None + ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: + raise NotImplementedError + def get( self, transaction=None, diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index c0fb55b78e..51ee311798 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -22,6 +22,7 @@ _item_to_document_ref, ) from google.cloud.firestore_v1 import query as query_mod +from google.cloud.firestore_v1 import aggregation from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import document from typing import Any, Callable, Generator, Tuple @@ -67,6 +68,14 @@ def _query(self) -> query_mod.Query: """ return query_mod.Query(self) + def _aggregation_query(self) -> aggregation.AggregationQuery: + """AggregationQuery factory. + + Returns: + :class:`~google.cloud.firestore_v1.aggregation_query.AggregationQuery` + """ + return aggregation.AggregationQuery(self._query()) + def add( self, document_data: dict, diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 49e8013c87..700493725f 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -18,6 +18,8 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ +from __future__ import annotations + from google.cloud import firestore_v1 from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.api_core import exceptions @@ -32,6 +34,7 @@ _collection_group_query_response_to_snapshot, _enum_from_direction, ) +from google.cloud.firestore_v1 import aggregation from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch @@ -234,6 +237,17 @@ def _retry_query_after_exception(self, exc, retry, transaction): return False + def count( + self, alias: str | None = None + ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + """ + Adds a count over the query. + + :type alias: str + :param alias: (Optional) The alias for the count + """ + return aggregation.AggregationQuery(self).count(alias=alias) + def stream( self, transaction=None, diff --git a/tests/system/test_system.py b/tests/system/test_system.py index c8a476e305..e51cd7ba23 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -17,6 +17,7 @@ import math import operator +import google.auth from google.oauth2 import service_account import pytest @@ -46,11 +47,13 @@ def _get_credentials_and_project(): if FIRESTORE_EMULATOR: credentials = EMULATOR_CREDS project = FIRESTORE_PROJECT - else: + elif FIRESTORE_CREDS: credentials = service_account.Credentials.from_service_account_file( FIRESTORE_CREDS ) project = FIRESTORE_PROJECT or credentials.project_id + else: + credentials, project = google.auth.default() return credentials, project @@ -536,6 +539,13 @@ def query_docs(client): operation() +@pytest.fixture +def query(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("a", "==", 1) + return query + + def test_query_stream_w_simple_field_eq_op(query_docs): collection, stored, allowed_vals = query_docs query = collection.where("a", "==", 1) @@ -1617,3 +1627,199 @@ def test_repro_391(client, cleanup): _, document = collection.add(data, document_id) assert len(set(collection.stream())) == len(document_ids) + + +def test_count_query_get_default_alias(query): + count_query = query.count() + result = count_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "field_1" + + +def test_count_query_get_with_alias(query): + count_query = query.count(alias="total") + result = count_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + + +def test_count_query_get_with_limit(query): + # count without limit + count_query = query.count(alias="total") + result = count_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 + + # count with limit + count_query = query.limit(2).count(alias="total") + + result = count_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 2 + + +def test_count_query_get_multiple_aggregations(query): + count_query = query.count(alias="total").count(alias="all") + + result = count_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +def test_count_query_get_multiple_aggregations_duplicated_alias(query): + count_query = query.count(alias="total").count(alias="total") + + with pytest.raises(InvalidArgument) as exc_info: + count_query.get() + + assert "Aggregation aliases contain duplicate alias" in exc_info.value.message + + +def test_count_query_get_empty_aggregation(query): + from google.cloud.firestore_v1.aggregation import AggregationQuery + + aggregation_query = AggregationQuery(query) + + with pytest.raises(InvalidArgument) as exc_info: + aggregation_query.get() + + assert "Aggregations can not be empty" in exc_info.value.message + + +def test_count_query_stream_default_alias(query): + count_query = query.count() + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + + +def test_count_query_stream_with_alias(query): + + count_query = query.count(alias="total") + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + + +def test_count_query_stream_with_limit(query): + # count without limit + count_query = query.count(alias="total") + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 5 + + # count with limit + count_query = query.limit(2).count(alias="total") + + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 2 + + +def test_count_query_stream_multiple_aggregations(query): + count_query = query.count(alias="total").count(alias="all") + + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +def test_count_query_stream_multiple_aggregations_duplicated_alias(query): + count_query = query.count(alias="total").count(alias="total") + + with pytest.raises(InvalidArgument) as exc_info: + for _ in count_query.stream(): + pass + + assert "Aggregation aliases contain duplicate alias" in exc_info.value.message + + +def test_count_query_stream_empty_aggregation(query): + from google.cloud.firestore_v1.aggregation import AggregationQuery + + aggregation_query = AggregationQuery(query) + + with pytest.raises(InvalidArgument) as exc_info: + for _ in aggregation_query.stream(): + pass + + assert "Aggregations can not be empty" in exc_info.value.message + + +@firestore.transactional +def create_in_transaction(collection_id, transaction, cleanup): + collection = client.collection(collection_id) + + query = collection.where("a", "==", 1) + count_query = query.count() + + result = count_query.get(transaction=transaction) + for r in result[0]: + assert r.value <= 2 + if r.value < 2: + document_id_3 = "doc3" + UNIQUE_RESOURCE_ID + document_3 = client.document(collection_id, document_id_3) + cleanup(document_3.delete) + document_3.create({"a": 1}) + else: + raise ValueError("Collection can't have more than 2 documents") + + +@firestore.transactional +def create_in_transaction_helper(transaction, client, collection_id, cleanup): + collection = client.collection(collection_id) + query = collection.where("a", "==", 1) + count_query = query.count() + result = count_query.get(transaction=transaction) + + for r in result[0]: + if r.value < 2: + document_id_3 = "doc3" + UNIQUE_RESOURCE_ID + document_3 = client.document(collection_id, document_id_3) + cleanup(document_3.delete) + document_3.create({"a": 1}) + else: # transaction is rolled back + raise ValueError("Collection can't have more than 2 docs") + + +def test_count_query_in_transaction(client, cleanup): + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + document_id_1 = "doc1" + UNIQUE_RESOURCE_ID + document_id_2 = "doc2" + UNIQUE_RESOURCE_ID + + document_1 = client.document(collection_id, document_id_1) + document_2 = client.document(collection_id, document_id_2) + + cleanup(document_1.delete) + cleanup(document_2.delete) + + document_1.create({"a": 1}) + document_2.create({"a": 1}) + + transaction = client.transaction() + + with pytest.raises(ValueError) as exc: + create_in_transaction_helper(transaction, client, collection_id, cleanup) + assert exc.exc_info == "Collection can't have more than 2 documents" + + collection = client.collection(collection_id) + + query = collection.where("a", "==", 1) + count_query = query.count() + result = count_query.get() + for r in result[0]: + assert r.value == 2 # there are still only 2 docs diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 662ce656f0..7b97f197c1 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -19,6 +19,8 @@ import pytest import pytest_asyncio import operator +import google.auth + from typing import Callable, Dict, List, Optional from google.oauth2 import service_account @@ -65,11 +67,13 @@ def _get_credentials_and_project(): if FIRESTORE_EMULATOR: credentials = EMULATOR_CREDS project = FIRESTORE_PROJECT - else: + elif FIRESTORE_CREDS: credentials = service_account.Credentials.from_service_account_file( FIRESTORE_CREDS ) project = FIRESTORE_PROJECT or credentials.project_id + else: + credentials, project = google.auth.default() return credentials, project @@ -579,6 +583,14 @@ async def query_docs(client): await operation() +@pytest_asyncio.fixture +async def async_query(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("a", "==", 1) + + return query + + async def test_query_stream_w_simple_field_eq_op(query_docs): collection, stored, allowed_vals = query_docs query = collection.where("a", "==", 1) @@ -1399,3 +1411,184 @@ async def _chain(*iterators): for iterator in iterators: async for value in iterator: yield value + + +async def test_count_async_query_get_default_alias(async_query): + count_query = async_query.count() + result = await count_query.get() + for r in result[0]: + assert r.alias == "field_1" + + +async def test_async_count_query_get_with_alias(async_query): + + count_query = async_query.count(alias="total") + result = await count_query.get() + for r in result[0]: + assert r.alias == "total" + + +async def test_async_count_query_get_with_limit(async_query): + + count_query = async_query.count(alias="total") + result = await count_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 + + # count with limit + count_query = async_query.limit(2).count(alias="total") + result = await count_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 2 + + +async def test_async_count_query_get_multiple_aggregations(async_query): + + count_query = async_query.count(alias="total").count(alias="all") + + result = await count_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +async def test_async_count_query_get_multiple_aggregations_duplicated_alias( + async_query, +): + + count_query = async_query.count(alias="total").count(alias="total") + + with pytest.raises(InvalidArgument) as exc_info: + await count_query.get() + + assert "Aggregation aliases contain duplicate alias" in exc_info.value.message + + +async def test_async_count_query_get_empty_aggregation(async_query): + from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery + + aggregation_query = AsyncAggregationQuery(async_query) + + with pytest.raises(InvalidArgument) as exc_info: + await aggregation_query.get() + + assert "Aggregations can not be empty" in exc_info.value.message + + +async def test_count_async_query_stream_default_alias(async_query): + + count_query = async_query.count() + + async for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + + +async def test_async_count_query_stream_with_alias(async_query): + + count_query = async_query.count(alias="total") + async for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + + +async def test_async_count_query_stream_with_limit(async_query): + # count without limit + count_query = async_query.count(alias="total") + async for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 5 + + # count with limit + count_query = async_query.limit(2).count(alias="total") + async for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 2 + + +async def test_async_count_query_stream_multiple_aggregations(async_query): + + count_query = async_query.count(alias="total").count(alias="all") + + async for result in count_query.stream(): + assert len(result) == 2 + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( + async_query, +): + + count_query = async_query.count(alias="total").count(alias="total") + + with pytest.raises(InvalidArgument) as exc_info: + async for _ in count_query.stream(): + pass + + assert "Aggregation aliases contain duplicate alias" in exc_info.value.message + + +async def test_async_count_query_stream_empty_aggregation(async_query): + from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery + + aggregation_query = AsyncAggregationQuery(async_query) + + with pytest.raises(InvalidArgument) as exc_info: + async for _ in aggregation_query.stream(): + pass + + assert "Aggregations can not be empty" in exc_info.value.message + + +@firestore.async_transactional +async def create_in_transaction_helper(transaction, client, collection_id, cleanup): + collection = client.collection(collection_id) + query = collection.where("a", "==", 1) + count_query = query.count() + result = await count_query.get(transaction=transaction) + + for r in result[0]: + if r.value < 2: + document_id_3 = "doc3" + UNIQUE_RESOURCE_ID + document_3 = client.document(collection_id, document_id_3) + cleanup(document_3.delete) + document_3.create({"a": 1}) + else: # transaction is rolled back + raise ValueError("Collection can't have more than 2 docs") + + +async def test_count_query_in_transaction(client, cleanup): + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + document_id_1 = "doc1" + UNIQUE_RESOURCE_ID + document_id_2 = "doc2" + UNIQUE_RESOURCE_ID + + document_1 = client.document(collection_id, document_id_1) + document_2 = client.document(collection_id, document_id_2) + + cleanup(document_1.delete) + cleanup(document_2.delete) + + await document_1.create({"a": 1}) + await document_2.create({"a": 1}) + + transaction = client.transaction() + + with pytest.raises(ValueError) as exc: + await create_in_transaction_helper(transaction, client, collection_id, cleanup) + assert exc.exc_info == "Collection can't have more than 2 documents" + + collection = client.collection(collection_id) + + query = collection.where("a", "==", 1) + count_query = query.count() + result = await count_query.get() + for r in result[0]: + assert r.value == 2 # there are still only 2 docs diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 3b09f9f9ad..5ff2891945 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -18,15 +18,19 @@ import typing import google + +from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.base_client import BaseClient from google.cloud.firestore_v1.document import DocumentReference, DocumentSnapshot from google.cloud._helpers import _datetime_to_pb_timestamp, UTC # type: ignore from google.cloud.firestore_v1._helpers import build_timestamp -from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.client import Client from google.protobuf.timestamp_pb2 import Timestamp # type: ignore +DEFAULT_TEST_PROJECT = "project-project" + + def make_test_credentials() -> google.auth.credentials.Credentials: # type: ignore import google.auth.credentials # type: ignore @@ -35,13 +39,63 @@ def make_test_credentials() -> google.auth.credentials.Credentials: # type: ign def make_client(project_name: typing.Optional[str] = None) -> Client: return Client( - project=project_name or "project-project", + project=project_name or DEFAULT_TEST_PROJECT, credentials=make_test_credentials(), ) -def make_async_client() -> AsyncClient: - return AsyncClient(project="project-project", credentials=make_test_credentials()) +def make_async_client(project=DEFAULT_TEST_PROJECT) -> AsyncClient: + return AsyncClient(project=project, credentials=make_test_credentials()) + + +def make_query(*args, **kwargs): + from google.cloud.firestore_v1.query import Query + + return Query(*args, **kwargs) + + +def make_async_query(*args, **kwargs): + from google.cloud.firestore_v1.async_query import AsyncQuery + + return AsyncQuery(*args, **kwargs) + + +def make_aggregation_query(*args, **kw): + from google.cloud.firestore_v1.aggregation import AggregationQuery + + return AggregationQuery(*args, **kw) + + +def make_async_aggregation_query(*args, **kw): + from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery + + return AsyncAggregationQuery(*args, **kw) + + +def make_aggregation_query_response(aggregations, read_time=None, transaction=None): + from google.cloud.firestore_v1.types import firestore + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import aggregation_result + + if read_time is None: + now = datetime.datetime.now(tz=datetime.timezone.utc) + read_time = _datetime_to_pb_timestamp(now) + + res = {} + for aggr in aggregations: + res[aggr.alias] = aggr.value + result = aggregation_result.AggregationResult( + aggregate_fields=_helpers.encode_dict(res) + ) + + kwargs = {} + kwargs["read_time"] = read_time + kwargs["result"] = result + if transaction is not None: + kwargs["transaction"] = transaction + + return firestore.RunAggregationQueryResponse(**kwargs) def build_test_timestamp( diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index 95cb595716..0a6dee40e3 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -19,6 +19,9 @@ import pytest +from tests.unit.v1._test_helpers import make_test_credentials + + def _make_geo_point(lat, lng): from google.cloud.firestore_v1._helpers import GeoPoint @@ -2564,16 +2567,10 @@ def _make_ref_string(project, database, *path): ) -def _make_credentials(): - import google.auth.credentials - - return mock.Mock(spec=google.auth.credentials.Credentials) - - def _make_client(project="quark"): from google.cloud.firestore_v1.client import Client - credentials = _make_credentials() + credentials = make_test_credentials() return Client(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py new file mode 100644 index 0000000000..7b07aa9afa --- /dev/null +++ b/tests/unit/v1/test_aggregation.py @@ -0,0 +1,476 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types +import mock +import pytest + + +from datetime import datetime, timezone, timedelta + +from google.cloud.firestore_v1.base_aggregation import ( + CountAggregation, + AggregationResult, +) +from tests.unit.v1._test_helpers import ( + make_aggregation_query, + make_aggregation_query_response, + make_client, + make_query, +) + +_PROJECT = "PROJECT" + + +def test_count_aggregation_to_pb(): + from google.cloud.firestore_v1.types import query as query_pb2 + + count_aggregation = CountAggregation(alias="total") + + expected_aggregation_query_pb = query_pb2.StructuredAggregationQuery.Aggregation() + expected_aggregation_query_pb.count = ( + query_pb2.StructuredAggregationQuery.Aggregation.Count() + ) + expected_aggregation_query_pb.alias = count_aggregation.alias + assert count_aggregation._to_protobuf() == expected_aggregation_query_pb + + +def test_aggregation_query_constructor(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + assert aggregation_query._collection_ref == query._parent + assert aggregation_query._nested_query == query + assert len(aggregation_query._aggregations) == 0 + assert aggregation_query._client == query._parent._client + + +def test_aggregation_query_add_aggregation(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.add_aggregation(CountAggregation(alias="all")) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + + +def test_aggregation_query_add_aggregations(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.add_aggregations( + [CountAggregation(alias="all"), CountAggregation(alias="total")] + ) + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[1].alias == "total" + + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert isinstance(aggregation_query._aggregations[1], CountAggregation) + + +def test_aggregation_query_count(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + + +def test_aggregation_query_count_twice(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all").count(alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[1].alias == "total" + + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert isinstance(aggregation_query._aggregations[1], CountAggregation) + + +def test_aggregation_query_to_protobuf(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all") + pb = aggregation_query._to_protobuf() + + assert pb.structured_query == parent._query()._to_protobuf() + assert len(pb.aggregations) == 1 + assert pb.aggregations[0] == aggregation_query._aggregations[0]._to_protobuf() + + +def test_aggregation_query_prep_stream(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all") + + request, kwargs = aggregation_query._prep_stream() + + parent_path, _ = parent._parent_info() + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + } + assert request == expected_request + assert kwargs == {"retry": None} + + +def test_aggregation_query_prep_stream_with_transaction(): + client = make_client() + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all") + + request, kwargs = aggregation_query._prep_stream(transaction=transaction) + + parent_path, _ = parent._parent_info() + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + } + assert request == expected_request + assert kwargs == {"retry": None} + + +def _aggregation_query_get_helper(retry=None, timeout=None, read_time=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud._helpers import _datetime_to_pb_timestamp + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_aggregation_query"]) + + # Attach the fake GAPIC to a real client. + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.count(alias="all") + + aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time) + response_pb = make_aggregation_query_response( + [aggregation_result], read_time=read_time + ) + firestore_api.run_aggregation_query.return_value = iter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + returned = aggregation_query.get(**kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + if read_time is not None: + result_datetime = _datetime_to_pb_timestamp(r.read_time) + assert result_datetime == read_time + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_aggregation_query_get(): + _aggregation_query_get_helper() + + +def test_aggregation_query_get_with_readtime(): + from google.cloud._helpers import _datetime_to_pb_timestamp + + one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) + read_time = _datetime_to_pb_timestamp(one_hour_ago) + _aggregation_query_get_helper(read_time=read_time) + + +def test_aggregation_query_get_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _aggregation_query_get_helper(retry=retry, timeout=timeout) + + +def test_aggregation_query_get_transaction(): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_aggregation_query"]) + + # Attach the fake GAPIC to a real client. + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + transaction = client.transaction() + + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.count(alias="all") + + aggregation_result = AggregationResult(alias="total", value=5) + response_pb = make_aggregation_query_response( + [aggregation_result], transaction=txn_id + ) + firestore_api.run_aggregation_query.return_value = iter([response_pb]) + retry = None + timeout = None + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + returned = aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +_not_passed = object() + + +def _aggregation_query_stream_w_retriable_exc_helper( + retry=_not_passed, + timeout=None, + transaction=None, + expect_retry=True, +): + from google.api_core import exceptions + from google.api_core import gapic_v1 + from google.cloud.firestore_v1 import _helpers + + if retry is _not_passed: + retry = gapic_v1.method.DEFAULT + + if transaction is not None: + expect_retry = False + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_aggregation_query", "_transport"]) + transport = firestore_api._transport = mock.Mock(spec=["run_aggregation_query"]) + stub = transport.run_aggregation_query = mock.create_autospec( + gapic_v1.method._GapicCallable + ) + stub._retry = mock.Mock(spec=["_predicate"]) + stub._predicate = lambda exc: True # pragma: NO COVER + + # Attach the fake GAPIC to a real client. + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + aggregation_result = AggregationResult(alias="total", value=5) + response_pb = make_aggregation_query_response([aggregation_result]) + + retriable_exc = exceptions.ServiceUnavailable("testing") + + def _stream_w_exception(*_args, **_kw): + yield response_pb + raise retriable_exc + + firestore_api.run_aggregation_query.side_effect = [_stream_w_exception(), iter([])] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + get_response = aggregation_query.stream(transaction=transaction, **kwargs) + + assert isinstance(get_response, types.GeneratorType) + if expect_retry: + returned = list(get_response) + else: + returned = [next(get_response)] + with pytest.raises(exceptions.ServiceUnavailable): + next(get_response) + + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + calls = firestore_api.run_aggregation_query.call_args_list + + if expect_retry: + assert len(calls) == 2 + else: + assert len(calls) == 1 + + if transaction is not None: + expected_transaction_id = transaction.id + else: + expected_transaction_id = None + + assert calls[0] == mock.call( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": expected_transaction_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + if expect_retry: + assert calls[1] == mock.call( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_aggregation_query_stream_w_retriable_exc_w_defaults(): + _aggregation_query_stream_w_retriable_exc_helper() + + +def test_aggregation_query_stream_w_retriable_exc_w_retry(): + retry = mock.Mock(spec=["_predicate"]) + retry._predicate = lambda exc: False + _aggregation_query_stream_w_retriable_exc_helper(retry=retry, expect_retry=False) + + +def test_aggregation_query_stream_w_retriable_exc_w_transaction(): + from google.cloud.firestore_v1 import transaction + + txn = transaction.Transaction(client=mock.Mock(spec=[])) + txn._id = b"DEADBEEF" + _aggregation_query_stream_w_retriable_exc_helper(transaction=txn) + + +def test_aggregation_from_query(): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_aggregation_query"]) + + # Attach the fake GAPIC to a real client. + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_query(parent) + + transaction = client.transaction() + + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + aggregation_result = AggregationResult(alias="total", value=5) + response_pb = make_aggregation_query_response( + [aggregation_result], transaction=txn_id + ) + firestore_api.run_aggregation_query.return_value = iter([response_pb]) + retry = None + timeout = None + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + aggregation_query = query.count(alias="total") + returned = aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py new file mode 100644 index 0000000000..6ed2f74b62 --- /dev/null +++ b/tests/unit/v1/test_async_aggregation.py @@ -0,0 +1,349 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +from datetime import datetime, timezone, timedelta + +from google.cloud.firestore_v1.base_aggregation import ( + CountAggregation, + AggregationResult, +) + +from tests.unit.v1.test__helpers import AsyncIter +from tests.unit.v1.test__helpers import AsyncMock +from tests.unit.v1._test_helpers import ( + make_async_client, + make_async_query, + make_async_aggregation_query, + make_aggregation_query_response, +) + + +_PROJECT = "PROJECT" + + +def test_async_aggregation_query_constructor(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + assert aggregation_query._collection_ref == parent + assert aggregation_query._nested_query == parent._query() + assert len(aggregation_query._aggregations) == 0 + assert aggregation_query._client == client + + +def test_async_aggregation_query_add_aggregation(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.add_aggregation(CountAggregation(alias="all")) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + + +def test_async_aggregation_query_add_aggregations(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.add_aggregations( + [CountAggregation(alias="all"), CountAggregation(alias="total")] + ) + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[1].alias == "total" + + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert isinstance(aggregation_query._aggregations[1], CountAggregation) + + +def test_async_aggregation_query_count(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.count(alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + + +def test_async_aggregation_query_count_twice(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.count(alias="all").count(alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[1].alias == "total" + + assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert isinstance(aggregation_query._aggregations[1], CountAggregation) + + +def test_async_aggregation_query_to_protobuf(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.count(alias="all") + pb = aggregation_query._to_protobuf() + + assert pb.structured_query == parent._query()._to_protobuf() + assert len(pb.aggregations) == 1 + assert pb.aggregations[0] == aggregation_query._aggregations[0]._to_protobuf() + + +def test_async_aggregation_query_prep_stream(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.count(alias="all") + + request, kwargs = aggregation_query._prep_stream() + + parent_path, _ = parent._parent_info() + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + } + assert request == expected_request + assert kwargs == {"retry": None} + + +def test_async_aggregation_query_prep_stream_with_transaction(): + client = make_async_client() + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.count(alias="all") + + request, kwargs = aggregation_query._prep_stream(transaction=transaction) + + parent_path, _ = parent._parent_info() + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + } + assert request == expected_request + assert kwargs == {"retry": None} + + +@pytest.mark.asyncio +async def _async_aggregation_query_get_helper(retry=None, timeout=None, read_time=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud._helpers import _datetime_to_pb_timestamp + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_aggregation_query"]) + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.count(alias="all") + + aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time) + response_pb = make_aggregation_query_response( + [aggregation_result], read_time=read_time + ) + firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + returned = await aggregation_query.get(**kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + if read_time is not None: + result_datetime = _datetime_to_pb_timestamp(r.read_time) + assert result_datetime == read_time + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_async_aggregation_query_get(): + await _async_aggregation_query_get_helper() + + +@pytest.mark.asyncio +async def test_async_aggregation_query_get_with_readtime(): + from google.cloud._helpers import _datetime_to_pb_timestamp + + one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) + read_time = _datetime_to_pb_timestamp(one_hour_ago) + await _async_aggregation_query_get_helper(read_time=read_time) + + +@pytest.mark.asyncio +async def test_async_aggregation_query_get_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _async_aggregation_query_get_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_async_aggregation_query_get_transaction(): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_aggregation_query"]) + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + transaction = client.transaction() + + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.count(alias="all") + + aggregation_result = AggregationResult(alias="total", value=5) + response_pb = make_aggregation_query_response( + [aggregation_result], transaction=txn_id + ) + firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) + retry = None + timeout = None + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + returned = await aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_async_aggregation_from_query(): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_aggregation_query"]) + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + + transaction = client.transaction() + + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + aggregation_result = AggregationResult(alias="total", value=5) + response_pb = make_aggregation_query_response( + [aggregation_result], transaction=txn_id + ) + firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) + retry = None + timeout = None + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + aggregation_query = query.count(alias="total") + returned = await aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. + parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 4a9e480a92..0599937cca 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -19,6 +19,7 @@ from tests.unit.v1.test__helpers import AsyncIter from tests.unit.v1.test__helpers import AsyncMock +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT, make_async_client def _make_async_collection_reference(*args, **kwargs): @@ -66,12 +67,36 @@ def test_asynccollectionreference_query_method_matching(): def test_asynccollectionreference_document_name_default(): - client = _make_client() + client = make_async_client() document = client.collection("test").document() # name is random, but assert it is not None assert document.id is not None +def test_async_collection_aggregation_query(): + from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery + + firestore_api = AsyncMock(spec=["create_document", "commit"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("grand-parent", client=client) + + assert isinstance(collection._aggregation_query(), AsyncAggregationQuery) + + +def test_async_collection_count(): + firestore_api = AsyncMock(spec=["create_document", "commit"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("grand-parent", client=client) + + alias = "total" + aggregation_query = collection.count(alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + + @pytest.mark.asyncio async def test_asynccollectionreference_add_auto_assigned(): from google.cloud.firestore_v1.types import document @@ -92,7 +117,7 @@ async def test_asynccollectionreference_add_auto_assigned(): firestore_api.commit.return_value = commit_response create_doc_response = document.Document() firestore_api.create_document.return_value = create_doc_response - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Actually make a collection. @@ -161,7 +186,7 @@ async def _add_helper(retry=None, timeout=None): firestore_api.commit.return_value = commit_response # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Actually make a collection and call add(). @@ -213,7 +238,7 @@ async def test_asynccollectionreference_chunkify(): from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import firestore - client = _make_client() + client = make_async_client() col = client.collection("my-collection") client._firestore_api_internal = mock.Mock(spec=["run_query"]) @@ -221,7 +246,7 @@ async def test_asynccollectionreference_chunkify(): results = [] for index in range(10): name = ( - f"projects/project-project/databases/(default)/" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/" f"documents/my-collection/{index}" ) results.append( @@ -268,7 +293,7 @@ async def _next_page(self): page, self._pages = self._pages[0], self._pages[1:] return Page(self, page, self.item_to_value) - client = _make_client() + client = make_async_client() template = client._database_string + "/documents/{}" document_ids = ["doc-1", "doc-2"] documents = [ @@ -443,16 +468,3 @@ def test_asynccollectionreference_recursive(): col = _make_async_collection_reference("collection") assert isinstance(col.recursive(), AsyncQuery) - - -def _make_credentials(): - import google.auth.credentials - - return mock.Mock(spec=google.auth.credentials.Credentials) - - -def _make_client(): - from google.cloud.firestore_v1.async_client import AsyncClient - - credentials = _make_credentials() - return AsyncClient(project="project-project", credentials=credentials) diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 82f52d0f34..41a5abff56 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -18,6 +18,7 @@ import pytest from tests.unit.v1.test__helpers import AsyncIter, AsyncMock +from tests.unit.v1._test_helpers import make_async_client def _make_async_document_reference(*args, **kwargs): @@ -76,7 +77,7 @@ async def _create_helper(retry=None, timeout=None): firestore_api.commit.return_value = _make_commit_repsonse() # Attach the fake GAPIC to a real client. - client = _make_client("dignity") + client = make_async_client("dignity") client._firestore_api_internal = firestore_api # Actually make a document and call create(). @@ -130,7 +131,7 @@ async def test_asyncdocumentreference_create_empty(): ) # Attach the fake GAPIC to a real client. - client = _make_client("dignity") + client = make_async_client("dignity") client._firestore_api_internal = firestore_api client.get_all = mock.MagicMock() client.get_all.exists.return_value = True @@ -175,7 +176,7 @@ async def _set_helper(merge=False, retry=None, timeout=None, **option_kwargs): firestore_api.commit.return_value = _make_commit_repsonse() # Attach the fake GAPIC to a real client. - client = _make_client("db-dee-bee") + client = make_async_client("db-dee-bee") client._firestore_api_internal = firestore_api # Actually make a document and call create(). @@ -244,7 +245,7 @@ async def _update_helper(retry=None, timeout=None, **option_kwargs): firestore_api.commit.return_value = _make_commit_repsonse() # Attach the fake GAPIC to a real client. - client = _make_client("potato-chip") + client = make_async_client("potato-chip") client._firestore_api_internal = firestore_api # Actually make a document and call create(). @@ -320,7 +321,7 @@ async def test_asyncdocumentreference_empty_update(): firestore_api.commit.return_value = _make_commit_repsonse() # Attach the fake GAPIC to a real client. - client = _make_client("potato-chip") + client = make_async_client("potato-chip") client._firestore_api_internal = firestore_api # Actually make a document and call create(). @@ -341,7 +342,7 @@ async def _delete_helper(retry=None, timeout=None, **option_kwargs): firestore_api.commit.return_value = _make_commit_repsonse() # Attach the fake GAPIC to a real client. - client = _make_client("donut-base") + client = make_async_client("donut-base") client._firestore_api_internal = firestore_api kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -421,7 +422,7 @@ async def _get_helper( response.found.create_time = create_time response.found.update_time = update_time - client = _make_client("donut-base") + client = make_async_client("donut-base") client._firestore_api_internal = firestore_api document_reference = _make_async_document_reference( "where", "we-are", client=client @@ -550,7 +551,7 @@ async def __aiter__(self, **_): firestore_api.mock_add_spec(spec=["list_collection_ids"]) firestore_api.list_collection_ids.return_value = Pager() - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -594,16 +595,3 @@ async def test_asyncdocumentreference_collections_w_retry_timeout(): @pytest.mark.asyncio async def test_asyncdocumentreference_collections_w_page_size(): await _collections_helper(page_size=10) - - -def _make_credentials(): - import google.auth.credentials - - return mock.Mock(spec=google.auth.credentials.Credentials) - - -def _make_client(project="project-project"): - from google.cloud.firestore_v1.async_client import AsyncClient - - credentials = _make_credentials() - return AsyncClient(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 4b7b83cede..b74a215c3f 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -19,19 +19,17 @@ from tests.unit.v1.test__helpers import AsyncIter from tests.unit.v1.test__helpers import AsyncMock -from tests.unit.v1.test_base_query import _make_credentials from tests.unit.v1.test_base_query import _make_query_response from tests.unit.v1.test_base_query import _make_cursor_pb - - -def _make_async_query(*args, **kwargs): - from google.cloud.firestore_v1.async_query import AsyncQuery - - return AsyncQuery(*args, **kwargs) +from tests.unit.v1._test_helpers import ( + DEFAULT_TEST_PROJECT, + make_async_client, + make_async_query, +) def test_asyncquery_constructor(): - query = _make_async_query(mock.sentinel.parent) + query = make_async_query(mock.sentinel.parent) assert query._parent is mock.sentinel.parent assert query._projection is None assert query._field_filters == () @@ -50,7 +48,7 @@ async def _get_helper(retry=None, timeout=None): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -66,7 +64,7 @@ async def _get_helper(retry=None, timeout=None): kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) returned = await query.get(**kwargs) assert isinstance(returned, list) @@ -112,7 +110,7 @@ async def test_asyncquery_get_limit_to_last(): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -130,7 +128,7 @@ async def test_asyncquery_get_limit_to_last(): firestore_api.run_query.return_value = AsyncIter([response_pb2, response_pb]) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) query = query.order_by( "snooze", direction=firestore.AsyncQuery.DESCENDING ).limit_to_last(2) @@ -164,7 +162,7 @@ async def test_asyncquery_get_limit_to_last(): @pytest.mark.asyncio async def test_asyncquery_chunkify_w_empty(): - client = _make_client() + client = make_async_client() firestore_api = AsyncMock(spec=["run_query"]) firestore_api.run_query.return_value = AsyncIter([]) client._firestore_api_internal = firestore_api @@ -182,10 +180,10 @@ async def test_asyncquery_chunkify_w_chunksize_lt_limit(): from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import firestore - client = _make_client() + client = make_async_client() firestore_api = AsyncMock(spec=["run_query"]) doc_ids = [ - f"projects/project-project/databases/(default)/documents/asdf/{index}" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/asdf/{index}" for index in range(5) ] responses1 = [ @@ -230,14 +228,14 @@ async def test_asyncquery_chunkify_w_chunksize_gt_limit(): from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import firestore - client = _make_client() + client = make_async_client() firestore_api = AsyncMock(spec=["run_query"]) responses = [ firestore.RunQueryResponse( document=document.Document( name=( - f"projects/project-project/databases/(default)/" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/" f"documents/asdf/{index}" ), ), @@ -265,7 +263,7 @@ async def _stream_helper(retry=None, timeout=None): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -280,7 +278,7 @@ async def _stream_helper(retry=None, timeout=None): kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) get_response = query.stream(**kwargs) @@ -321,11 +319,11 @@ async def test_asyncquery_stream_w_retry_timeout(): @pytest.mark.asyncio async def test_asyncquery_stream_with_limit_to_last(): # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() # Make a **real** collection reference as parent. parent = client.collection("dee") # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) query = query.limit_to_last(2) stream_response = query.stream() @@ -340,7 +338,7 @@ async def test_asyncquery_stream_with_transaction(): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Create a real-ish transaction for this client. @@ -359,7 +357,7 @@ async def test_asyncquery_stream_with_transaction(): firestore_api.run_query.return_value = AsyncIter([response_pb]) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) get_response = query.stream(transaction=transaction) assert isinstance(get_response, types.AsyncGeneratorType) returned = [x async for x in get_response] @@ -388,12 +386,12 @@ async def test_asyncquery_stream_no_results(): firestore_api.run_query.return_value = run_query_response # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. parent = client.collection("dah", "dah", "dum") - query = _make_async_query(parent) + query = make_async_query(parent) get_response = query.stream() assert isinstance(get_response, types.AsyncGeneratorType) @@ -421,12 +419,12 @@ async def test_asyncquery_stream_second_response_in_empty_stream(): firestore_api.run_query.return_value = run_query_response # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. parent = client.collection("dah", "dah", "dum") - query = _make_async_query(parent) + query = make_async_query(parent) get_response = query.stream() assert isinstance(get_response, types.AsyncGeneratorType) @@ -450,7 +448,7 @@ async def test_asyncquery_stream_with_skipped_results(): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -465,7 +463,7 @@ async def test_asyncquery_stream_with_skipped_results(): firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) get_response = query.stream() assert isinstance(get_response, types.AsyncGeneratorType) returned = [x async for x in get_response] @@ -492,7 +490,7 @@ async def test_asyncquery_stream_empty_after_first_response(): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -507,7 +505,7 @@ async def test_asyncquery_stream_empty_after_first_response(): firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) get_response = query.stream() assert isinstance(get_response, types.AsyncGeneratorType) returned = [x async for x in get_response] @@ -534,7 +532,7 @@ async def test_asyncquery_stream_w_collection_group(): firestore_api = AsyncMock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -550,7 +548,7 @@ async def test_asyncquery_stream_w_collection_group(): firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) # Execute the query and check the response. - query = _make_async_query(parent) + query = make_async_query(parent) query._all_descendants = True get_response = query.stream() assert isinstance(get_response, types.AsyncGeneratorType) @@ -605,7 +603,7 @@ async def _get_partitions_helper(retry=None, timeout=None): firestore_api = AsyncMock(spec=["partition_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_async_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -663,7 +661,7 @@ async def test_asynccollectiongroup_get_partitions_w_retry_timeout(): @pytest.mark.asyncio async def test_asynccollectiongroup_get_partitions_w_filter(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_async_client() parent = client.collection("charles") # Make a query that fails to partition @@ -675,7 +673,7 @@ async def test_asynccollectiongroup_get_partitions_w_filter(): @pytest.mark.asyncio async def test_asynccollectiongroup_get_partitions_w_projection(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_async_client() parent = client.collection("charles") # Make a query that fails to partition @@ -687,7 +685,7 @@ async def test_asynccollectiongroup_get_partitions_w_projection(): @pytest.mark.asyncio async def test_asynccollectiongroup_get_partitions_w_limit(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_async_client() parent = client.collection("charles") # Make a query that fails to partition @@ -699,17 +697,10 @@ async def test_asynccollectiongroup_get_partitions_w_limit(): @pytest.mark.asyncio async def test_asynccollectiongroup_get_partitions_w_offset(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_async_client() parent = client.collection("charles") # Make a query that fails to partition query = _make_async_collection_group(parent).offset(10) with pytest.raises(ValueError): [i async for i in query.get_partitions(2)] - - -def _make_client(project="project-project"): - from google.cloud.firestore_v1.async_client import AsyncClient - - credentials = _make_credentials() - return AsyncClient(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index c17fb31eaf..c4dbe72106 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -15,6 +15,8 @@ import mock import pytest +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT + def _make_base_collection_reference(*args, **kwargs): from google.cloud.firestore_v1.base_collection import BaseCollectionReference @@ -402,4 +404,4 @@ def _make_client(): from google.cloud.firestore_v1.client import Client credentials = _make_credentials() - return Client(project="project-project", credentials=credentials) + return Client(project=DEFAULT_TEST_PROJECT, credentials=credentials) diff --git a/tests/unit/v1/test_base_document.py b/tests/unit/v1/test_base_document.py index d3a59d5adf..b4ed2730f8 100644 --- a/tests/unit/v1/test_base_document.py +++ b/tests/unit/v1/test_base_document.py @@ -16,6 +16,8 @@ import mock import pytest +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT + def _make_base_document_reference(*args, **kwargs): from google.cloud.firestore_v1.base_document import BaseDocumentReference @@ -434,7 +436,7 @@ def _make_credentials(): return mock.Mock(spec=google.auth.credentials.Credentials) -def _make_client(project="project-project"): +def _make_client(project=DEFAULT_TEST_PROJECT): from google.cloud.firestore_v1.client import Client credentials = _make_credentials() diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 790f170235..818e3e7b88 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -17,6 +17,8 @@ import mock import pytest +from tests.unit.v1._test_helpers import make_client + def _make_base_query(*args, **kwargs): from google.cloud.firestore_v1.base_query import BaseQuery @@ -395,13 +397,13 @@ def test_basequery_limit_to_last(): def test_basequery__resolve_chunk_size(): # With a global limit - query = _make_client().collection("asdf").limit(5) + query = make_client().collection("asdf").limit(5) assert query._resolve_chunk_size(3, 10) == 2 assert query._resolve_chunk_size(3, 1) == 1 assert query._resolve_chunk_size(3, 2) == 2 # With no limit - query = _make_client().collection("asdf")._query() + query = make_client().collection("asdf")._query() assert query._resolve_chunk_size(3, 10) == 10 assert query._resolve_chunk_size(3, 1) == 1 assert query._resolve_chunk_size(3, 2) == 2 @@ -1267,7 +1269,7 @@ class DerivedQuery(BaseQuery): def _get_collection_reference_class(): return CollectionReference - query = DerivedQuery(_make_client().collection("asdf")) + query = DerivedQuery(make_client().collection("asdf")) assert isinstance(query.recursive().recursive(), DerivedQuery) @@ -1471,7 +1473,7 @@ def test__query_response_to_snapshot_response(): from google.cloud.firestore_v1.base_query import _query_response_to_snapshot from google.cloud.firestore_v1.document import DocumentSnapshot - client = _make_client() + client = make_client() collection = client.collection("a", "b", "c") _, expected_prefix = collection._parent_info() @@ -1519,7 +1521,7 @@ def test__collection_group_query_response_to_snapshot_response(): _collection_group_query_response_to_snapshot, ) - client = _make_client() + client = make_client() collection = client.collection("a", "b", "c") other_collection = client.collection("a", "b", "d") to_match = other_collection.document("gigantic") @@ -1536,19 +1538,6 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.update_time == response_pb._pb.document.update_time -def _make_credentials(): - import google.auth.credentials - - return mock.Mock(spec=google.auth.credentials.Credentials) - - -def _make_client(project="project-project"): - from google.cloud.firestore_v1.client import Client - - credentials = _make_credentials() - return Client(project=project, credentials=credentials) - - def _make_order_pb(field_path, direction): from google.cloud.firestore_v1.types import query diff --git a/tests/unit/v1/test_bundle.py b/tests/unit/v1/test_bundle.py index 6b480f84c8..8508a79b21 100644 --- a/tests/unit/v1/test_bundle.py +++ b/tests/unit/v1/test_bundle.py @@ -24,12 +24,16 @@ from google.cloud.firestore_v1 import query as query_mod from tests.unit.v1 import _test_helpers +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT + class _CollectionQueryMixin: # Path to each document where we don't specify custom collection names or # document Ids - doc_key: str = "projects/project-project/databases/(default)/documents/col/doc" + doc_key: str = ( + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc" + ) @staticmethod def build_results_iterable(items): @@ -206,13 +210,13 @@ def test_add_query(self): assert bundle.named_queries.get("asdf") is not None assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-1" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-1" ] is not None ) assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-2" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-2" ] is not None ) @@ -301,7 +305,7 @@ def test_get_document(self): assert ( _helpers._get_document_from_bundle( bundle, - document_id="projects/project-project/databases/(default)/documents/col/doc-1", + document_id=f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-1", ) is not None ) @@ -309,7 +313,7 @@ def test_get_document(self): assert ( _helpers._get_document_from_bundle( bundle, - document_id="projects/project-project/databases/(default)/documents/col/doc-0", + document_id=f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-0", ) is None ) @@ -350,13 +354,13 @@ def test_async_query(self): assert bundle.named_queries.get("asdf") is not None assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-1" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-1" ] is not None ) assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-2" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-2" ] is not None ) @@ -409,13 +413,13 @@ def test_build_round_trip_emojis(self): assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-1" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-1" ].snapshot._data["smile"] == smile ) assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-2" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-2" ].snapshot._data["compound"] == mermaid ) @@ -437,13 +441,13 @@ def test_build_round_trip_more_unicode(self): assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-1" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-1" ].snapshot._data["bano"] == bano ) assert ( bundle.documents[ - "projects/project-project/databases/(default)/documents/col/doc-2" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/col/doc-2" ].snapshot._data["international"] == chinese_characters ) diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 36492722e0..04e6e21985 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -16,6 +16,8 @@ import mock +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT + def _make_collection_reference(*args, **kwargs): from google.cloud.firestore_v1.collection import CollectionReference @@ -47,6 +49,39 @@ def test_query_method_matching(): assert query_methods <= collection_methods +def test_collection_aggregation_query(): + from google.cloud.firestore_v1.aggregation import AggregationQuery + + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + assert isinstance(collection._aggregation_query(), AggregationQuery) + + +def test_collection_count(): + + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + alias = "total" + aggregation_query = collection.count(alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + + def test_constructor(): collection_id1 = "rooms" document_id = "roomA" @@ -387,7 +422,7 @@ def test_chunkify(): results.append( RunQueryResponse( document=Document( - name=f"projects/project-project/databases/(default)/documents/my-collection/{index}", + name=f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/my-collection/{index}", ), ), ) diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index df52a7c3e6..210591b430 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -16,6 +16,8 @@ import mock import pytest +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT + def _make_document_reference(*args, **kwargs): from google.cloud.firestore_v1.document import DocumentReference @@ -572,7 +574,7 @@ def _make_credentials(): return mock.Mock(spec=google.auth.credentials.Credentials) -def _make_client(project="project-project"): +def _make_client(project=DEFAULT_TEST_PROJECT): from google.cloud.firestore_v1.client import Client credentials = _make_credentials() diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index f82036c4be..3e529d9a4d 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -17,19 +17,14 @@ import mock import pytest -from tests.unit.v1.test_base_query import _make_credentials from tests.unit.v1.test_base_query import _make_cursor_pb from tests.unit.v1.test_base_query import _make_query_response - -def _make_query(*args, **kwargs): - from google.cloud.firestore_v1.query import Query - - return Query(*args, **kwargs) +from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT, make_client, make_query def test_query_constructor(): - query = _make_query(mock.sentinel.parent) + query = make_query(mock.sentinel.parent) assert query._parent is mock.sentinel.parent assert query._projection is None assert query._field_filters == () @@ -48,7 +43,7 @@ def _query_get_helper(retry=None, timeout=None): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -64,7 +59,7 @@ def _query_get_helper(retry=None, timeout=None): kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) returned = query.get(**kwargs) assert isinstance(returned, list) @@ -107,7 +102,7 @@ def test_query_get_limit_to_last(): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -125,7 +120,7 @@ def test_query_get_limit_to_last(): firestore_api.run_query.return_value = iter([response_pb2, response_pb]) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) query = query.order_by( "snooze", direction=firestore.Query.DESCENDING ).limit_to_last(2) @@ -155,7 +150,7 @@ def test_query_get_limit_to_last(): def test_query_chunkify_w_empty(): - client = _make_client() + client = make_client() firestore_api = mock.Mock(spec=["run_query"]) firestore_api.run_query.return_value = iter([]) client._firestore_api_internal = firestore_api @@ -170,10 +165,10 @@ def test_query_chunkify_w_chunksize_lt_limit(): from google.cloud.firestore_v1.types.document import Document from google.cloud.firestore_v1.types.firestore import RunQueryResponse - client = _make_client() + client = make_client() firestore_api = mock.Mock(spec=["run_query"]) doc_ids = [ - f"projects/project-project/databases/(default)/documents/asdf/{index}" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/asdf/{index}" for index in range(5) ] responses1 = [ @@ -215,10 +210,10 @@ def test_query_chunkify_w_chunksize_gt_limit(): from google.cloud.firestore_v1.types.document import Document from google.cloud.firestore_v1.types.firestore import RunQueryResponse - client = _make_client() + client = make_client() firestore_api = mock.Mock(spec=["run_query"]) doc_ids = [ - f"projects/project-project/databases/(default)/documents/asdf/{index}" + f"projects/{DEFAULT_TEST_PROJECT}/databases/(default)/documents/asdf/{index}" for index in range(5) ] responses = [ @@ -246,7 +241,7 @@ def _query_stream_helper(retry=None, timeout=None): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -261,7 +256,7 @@ def _query_stream_helper(retry=None, timeout=None): kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) get_response = query.stream(**kwargs) @@ -299,11 +294,11 @@ def test_query_stream_w_retry_timeout(): def test_query_stream_with_limit_to_last(): # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() # Make a **real** collection reference as parent. parent = client.collection("dee") # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) query = query.limit_to_last(2) stream_response = query.stream() @@ -317,7 +312,7 @@ def test_query_stream_with_transaction(): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Create a real-ish transaction for this client. @@ -336,7 +331,7 @@ def test_query_stream_with_transaction(): firestore_api.run_query.return_value = iter([response_pb]) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) get_response = query.stream(transaction=transaction) assert isinstance(get_response, types.GeneratorType) returned = list(get_response) @@ -364,12 +359,12 @@ def test_query_stream_no_results(): firestore_api.run_query.return_value = run_query_response # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. parent = client.collection("dah", "dah", "dum") - query = _make_query(parent) + query = make_query(parent) get_response = query.stream() assert isinstance(get_response, types.GeneratorType) @@ -397,12 +392,12 @@ def test_query_stream_second_response_in_empty_stream(): firestore_api.run_query.return_value = run_query_response # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. parent = client.collection("dah", "dah", "dum") - query = _make_query(parent) + query = make_query(parent) get_response = query.stream() assert isinstance(get_response, types.GeneratorType) @@ -425,7 +420,7 @@ def test_query_stream_with_skipped_results(): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -440,7 +435,7 @@ def test_query_stream_with_skipped_results(): firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) get_response = query.stream() assert isinstance(get_response, types.GeneratorType) returned = list(get_response) @@ -466,7 +461,7 @@ def test_query_stream_empty_after_first_response(): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -481,7 +476,7 @@ def test_query_stream_empty_after_first_response(): firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) get_response = query.stream() assert isinstance(get_response, types.GeneratorType) returned = list(get_response) @@ -507,7 +502,7 @@ def test_query_stream_w_collection_group(): firestore_api = mock.Mock(spec=["run_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -523,7 +518,7 @@ def test_query_stream_w_collection_group(): firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) query._all_descendants = True get_response = query.stream() assert isinstance(get_response, types.GeneratorType) @@ -574,7 +569,7 @@ def _query_stream_w_retriable_exc_helper( stub._predicate = lambda exc: True # pragma: NO COVER # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -595,7 +590,7 @@ def _stream_w_exception(*_args, **_kw): kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - query = _make_query(parent) + query = make_query(parent) get_response = query.stream(transaction=transaction, **kwargs) @@ -669,7 +664,7 @@ def test_query_stream_w_retriable_exc_w_transaction(): @mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True) def test_query_on_snapshot(watch): - query = _make_query(mock.sentinel.parent) + query = make_query(mock.sentinel.parent) query.on_snapshot(None) watch.for_query.assert_called_once() @@ -705,7 +700,7 @@ def _collection_group_get_partitions_helper(retry=None, timeout=None): firestore_api = mock.Mock(spec=["partition_query"]) # Attach the fake GAPIC to a real client. - client = _make_client() + client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. @@ -761,7 +756,7 @@ def test_collection_group_get_partitions_w_retry_timeout(): def test_collection_group_get_partitions_w_filter(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_client() parent = client.collection("charles") # Make a query that fails to partition @@ -772,7 +767,7 @@ def test_collection_group_get_partitions_w_filter(): def test_collection_group_get_partitions_w_projection(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_client() parent = client.collection("charles") # Make a query that fails to partition @@ -783,7 +778,7 @@ def test_collection_group_get_partitions_w_projection(): def test_collection_group_get_partitions_w_limit(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_client() parent = client.collection("charles") # Make a query that fails to partition @@ -794,17 +789,10 @@ def test_collection_group_get_partitions_w_limit(): def test_collection_group_get_partitions_w_offset(): # Make a **real** collection reference as parent. - client = _make_client() + client = make_client() parent = client.collection("charles") # Make a query that fails to partition query = _make_collection_group(parent).offset(10) with pytest.raises(ValueError): list(query.get_partitions(2)) - - -def _make_client(project="project-project"): - from google.cloud.firestore_v1.client import Client - - credentials = _make_credentials() - return Client(project=project, credentials=credentials)