Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve AsyncQuery typing #782

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from google.cloud.firestore_v1.transaction import Transaction


class AsyncCollectionReference(BaseCollectionReference):
class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
"""A reference to a collection in a Firestore database.
The collection may already exist or this class can facilitate creation
Expand Down
12 changes: 7 additions & 5 deletions google/cloud/firestore_v1/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,15 @@ def _rpc_metadata(self):

return self._rpc_metadata_internal

def collection(self, *collection_path) -> BaseCollectionReference:
def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]:
raise NotImplementedError

def collection_group(self, collection_id: str) -> BaseQuery:
raise NotImplementedError

def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference:
def _get_collection_reference(
self, collection_id: str
) -> BaseCollectionReference[BaseQuery]:
"""Checks validity of collection_id and then uses subclasses collection implementation.
Args:
Expand Down Expand Up @@ -325,7 +327,7 @@ def _document_path_helper(self, *document_path) -> List[str]:

def recursive_delete(
self,
reference: Union[BaseCollectionReference, BaseDocumentReference],
reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference],
bulk_writer: Optional["BulkWriter"] = None, # type: ignore
) -> int:
raise NotImplementedError
Expand Down Expand Up @@ -459,8 +461,8 @@ def collections(
retry: retries.Retry = None,
timeout: float = None,
) -> Union[
AsyncGenerator[BaseCollectionReference, Any],
Generator[BaseCollectionReference, Any, Any],
AsyncGenerator[BaseCollectionReference[BaseQuery], Any],
Generator[BaseCollectionReference[BaseQuery], Any, Any],
]:
raise NotImplementedError

Expand Down
27 changes: 14 additions & 13 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AsyncGenerator,
Coroutine,
Generator,
Generic,
AsyncIterator,
Iterator,
Iterable,
Expand All @@ -38,13 +39,13 @@

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import BaseQuery
from google.cloud.firestore_v1.base_query import QueryType
from google.cloud.firestore_v1.transaction import Transaction

_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"


class BaseCollectionReference(object):
class BaseCollectionReference(Generic[QueryType]):
"""A reference to a collection in a Firestore database.

The collection may already exist or this class can facilitate creation
Expand Down Expand Up @@ -108,7 +109,7 @@ def parent(self):
parent_path = self._path[:-1]
return self._client.document(*parent_path)

def _query(self) -> BaseQuery:
def _query(self) -> QueryType:
raise NotImplementedError

def _aggregation_query(self) -> BaseAggregationQuery:
Expand Down Expand Up @@ -215,10 +216,10 @@ def list_documents(
]:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
def recursive(self) -> QueryType:
return self._query().recursive()

def select(self, field_paths: Iterable[str]) -> BaseQuery:
def select(self, field_paths: Iterable[str]) -> QueryType:
"""Create a "select" query with this collection as parent.

See
Expand All @@ -244,7 +245,7 @@ def where(
value=None,
*,
filter=None
) -> BaseQuery:
) -> QueryType:
"""Create a "where" query with this collection as parent.

See
Expand Down Expand Up @@ -290,7 +291,7 @@ def where(
else:
return query.where(filter=filter)

def order_by(self, field_path: str, **kwargs) -> BaseQuery:
def order_by(self, field_path: str, **kwargs) -> QueryType:
"""Create an "order by" query with this collection as parent.

See
Expand All @@ -312,7 +313,7 @@ def order_by(self, field_path: str, **kwargs) -> BaseQuery:
query = self._query()
return query.order_by(field_path, **kwargs)

def limit(self, count: int) -> BaseQuery:
def limit(self, count: int) -> QueryType:
"""Create a limited query with this collection as parent.

.. note::
Expand Down Expand Up @@ -355,7 +356,7 @@ def limit_to_last(self, count: int):
query = self._query()
return query.limit_to_last(count)

def offset(self, num_to_skip: int) -> BaseQuery:
def offset(self, num_to_skip: int) -> QueryType:
"""Skip to an offset in a query with this collection as parent.

See
Expand All @@ -375,7 +376,7 @@ def offset(self, num_to_skip: int) -> BaseQuery:

def start_at(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""Start query at a cursor with this collection as parent.

See
Expand All @@ -398,7 +399,7 @@ def start_at(

def start_after(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""Start query after a cursor with this collection as parent.

See
Expand All @@ -421,7 +422,7 @@ def start_after(

def end_before(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""End query before a cursor with this collection as parent.

See
Expand All @@ -444,7 +445,7 @@ def end_before(

def end_at(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""End query at a cursor with this collection as parent.

See
Expand Down
49 changes: 29 additions & 20 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
Optional,
Tuple,
Type,
TypeVar,
Union,
)

Expand Down Expand Up @@ -102,6 +103,8 @@

_not_passed = object()

QueryType = TypeVar("QueryType", bound="BaseQuery")


class BaseFilter(abc.ABC):
"""Base class for Filters"""
Expand Down Expand Up @@ -319,7 +322,7 @@ def _client(self):
"""
return self._parent._client

def select(self, field_paths: Iterable[str]) -> "BaseQuery":
def select(self: QueryType, field_paths: Iterable[str]) -> QueryType:
"""Project documents matching query to a limited set of fields.

See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
Expand Down Expand Up @@ -354,7 +357,7 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery":
return self._copy(projection=new_projection)

def _copy(
self,
self: QueryType,
*,
projection: Optional[query.StructuredQuery.Projection] = _not_passed,
field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed,
Expand All @@ -366,7 +369,7 @@ def _copy(
end_at: Optional[Tuple[dict, bool]] = _not_passed,
all_descendants: Optional[bool] = _not_passed,
recursive: Optional[bool] = _not_passed,
) -> "BaseQuery":
) -> QueryType:
return self.__class__(
self._parent,
projection=self._evaluate_param(projection, self._projection),
Expand All @@ -389,13 +392,13 @@ def _evaluate_param(self, value, fallback_value):
return value if value is not _not_passed else fallback_value

def where(
self,
self: QueryType,
field_path: Optional[str] = None,
op_string: Optional[str] = None,
value=None,
*,
filter=None,
) -> "BaseQuery":
) -> QueryType:
"""Filter the query on a field.

See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
Expand Down Expand Up @@ -492,7 +495,9 @@ def _make_order(field_path, direction) -> StructuredQuery.Order:
direction=_enum_from_direction(direction),
)

def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
def order_by(
self: QueryType, field_path: str, direction: str = ASCENDING
) -> QueryType:
"""Modify the query to add an order clause on a specific field.

See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
Expand Down Expand Up @@ -526,7 +531,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
new_orders = self._orders + (order_pb,)
return self._copy(orders=new_orders)

def limit(self, count: int) -> "BaseQuery":
def limit(self: QueryType, count: int) -> QueryType:
"""Limit a query to return at most `count` matching results.

If the current query already has a `limit` set, this will override it.
Expand All @@ -545,7 +550,7 @@ def limit(self, count: int) -> "BaseQuery":
"""
return self._copy(limit=count, limit_to_last=False)

def limit_to_last(self, count: int) -> "BaseQuery":
def limit_to_last(self: QueryType, count: int) -> QueryType:
"""Limit a query to return the last `count` matching results.
If the current query already has a `limit_to_last`
set, this will override it.
Expand All @@ -570,7 +575,7 @@ def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int:
return max(self._limit - num_loaded, 0)
return chunk_size

def offset(self, num_to_skip: int) -> "BaseQuery":
def offset(self: QueryType, num_to_skip: int) -> QueryType:
"""Skip to an offset in a query.

If the current query already has specified an offset, this will
Expand Down Expand Up @@ -601,11 +606,11 @@ def _check_snapshot(self, document_snapshot) -> None:
raise ValueError("Cannot use snapshot from another collection as a cursor.")

def _cursor_helper(
self,
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
before: bool,
start: bool,
) -> "BaseQuery":
) -> QueryType:
"""Set values to be used for a ``start_at`` or ``end_at`` cursor.

The values will later be used in a query protobuf.
Expand Down Expand Up @@ -658,8 +663,9 @@ def _cursor_helper(
return self._copy(**query_kwargs)

def start_at(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""Start query results at a particular document value.

The result set will **include** the document specified by
Expand Down Expand Up @@ -690,8 +696,9 @@ def start_at(
return self._cursor_helper(document_fields_or_snapshot, before=True, start=True)

def start_after(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""Start query results after a particular document value.

The result set will **exclude** the document specified by
Expand Down Expand Up @@ -723,8 +730,9 @@ def start_after(
)

def end_before(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""End query results before a particular document value.

The result set will **exclude** the document specified by
Expand Down Expand Up @@ -756,8 +764,9 @@ def end_before(
)

def end_at(
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
) -> "BaseQuery":
self: QueryType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it shouldn't be necessary to include type annotations for self (mypy doesn't require it, and we typically don't include them in other GCP libraries).

Is there a reason it's needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the feedback! I wanted to make it explicit on the function signature that the return type was of the same type as the object calling it.

document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
) -> QueryType:
"""End query results at a particular document value.

The result set will **include** the document specified by
Expand Down Expand Up @@ -1003,7 +1012,7 @@ def stream(
def on_snapshot(self, callback) -> NoReturn:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
def recursive(self: QueryType) -> QueryType:
"""Returns a copy of this query whose iterator will yield all matching
documents as well as each of their descendent subcollections and documents.

Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from google.cloud.firestore_v1.transaction import Transaction


class CollectionReference(BaseCollectionReference):
class CollectionReference(BaseCollectionReference[query_mod.Query]):
"""A reference to a collection in a Firestore database.
The collection may already exist or this class can facilitate creation
Expand Down