diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 552d296e64..c3091e75aa 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -236,16 +236,29 @@ def where(self, field_path: str, op_string: str, value) -> BaseQuery: field_path (str): A field path (``.``-delimited list of field names) for the field to filter on. op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=`` - and ``>``. + Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, + and ``in``. value (Any): The value to compare the field against in the filter. If ``value`` is :data:`None` or a NaN, then ``==`` is the only - allowed operation. + allowed operation. If ``op_string`` is ``in``, ``value`` + must be a sequence of values. Returns: :class:`~google.cloud.firestore_v1.query.Query`: A filtered query. """ + if field_path == "__name__" and op_string == "in": + wrapped_names = [] + + for name in value: + + if isinstance(name, str): + name = self.document(name) + + wrapped_names.append(name) + + value = wrapped_names + query = self._query() return query.where(field_path, op_string, value) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index 8d4b783336..c17fb31eaf 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -217,6 +217,45 @@ def test_basecollectionreference_where(mock_query): assert query == mock_query.where.return_value +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_where_w___name___w_value_as_list_of_str(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + client = _make_client() + collection = _make_base_collection_reference("collection", client=client) + field_path = "__name__" + op_string = "in" + names = ["hello", "world"] + + query = collection.where(field_path, op_string, names) + + expected_refs = [collection.document(name) for name in names] + mock_query.where.assert_called_once_with(field_path, op_string, expected_refs) + assert query == mock_query.where.return_value + + +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_where_w___name___w_value_as_list_of_docref(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + client = _make_client() + collection = _make_base_collection_reference("collection", client=client) + field_path = "__name__" + op_string = "in" + refs = [collection.document("hello"), collection.document("world")] + + query = collection.where(field_path, op_string, refs) + + mock_query.where.assert_called_once_with(field_path, op_string, refs) + assert query == mock_query.where.return_value + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) def test_basecollectionreference_order_by(mock_query): from google.cloud.firestore_v1.base_query import BaseQuery