From 49e7ce3358357a42145460265e37a8b639726e81 Mon Sep 17 00:00:00 2001 From: HemangChothani <50404902+HemangChothani@users.noreply.github.com> Date: Tue, 16 Jul 2019 00:24:07 +0530 Subject: [PATCH] Add 'Transaction.get' / 'Transaction.get_all'. (#8628) Closes #6557. --- .../google/cloud/firestore_v1/transaction.py | 33 +++++++++++++++++ firestore/tests/unit/v1/test_transaction.py | 37 ++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/firestore/google/cloud/firestore_v1/transaction.py b/firestore/google/cloud/firestore_v1/transaction.py index 1e28cc9ac431..9d4068c75a88 100644 --- a/firestore/google/cloud/firestore_v1/transaction.py +++ b/firestore/google/cloud/firestore_v1/transaction.py @@ -23,6 +23,8 @@ from google.api_core import exceptions from google.cloud.firestore_v1 import batch from google.cloud.firestore_v1 import types +from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.query import Query MAX_ATTEMPTS = 5 @@ -200,6 +202,37 @@ def _commit(self): self._clean_up() return list(commit_response.write_results) + def get_all(self, references): + """Retrieves multiple documents from Firestore. + + Args: + references (List[.DocumentReference, ...]): Iterable of document + references to be retrieved. + + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + return self._client.get_all(references, transaction=self._id) + + def get(self, ref_or_query): + """ + Retrieve a document or a query result from the database. + Args: + ref_or_query The document references or query object to return. + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + if isinstance(ref_or_query, DocumentReference): + return self._client.get_all([ref_or_query], transaction=self._id) + elif isinstance(ref_or_query, Query): + return ref_or_query.stream(transaction=self._id) + else: + raise ValueError( + 'Value for argument "ref_or_query" must be a DocumentReference or a Query.' + ) + class _Transactional(object): """Provide a callable object to use as a transactional decorater. diff --git a/firestore/tests/unit/v1/test_transaction.py b/firestore/tests/unit/v1/test_transaction.py index ed578ad3eea6..8cae24a23831 100644 --- a/firestore/tests/unit/v1/test_transaction.py +++ b/firestore/tests/unit/v1/test_transaction.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest - import mock @@ -329,6 +328,42 @@ def test__commit_failure(self): metadata=client._rpc_metadata, ) + def test_get_all(self): + client = mock.Mock(spec=["get_all"]) + transaction = self._make_one(client) + ref1, ref2 = mock.Mock(), mock.Mock() + result = transaction.get_all([ref1, ref2]) + client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction.id) + self.assertIs(result, client.get_all.return_value) + + def test_get_document_ref(self): + from google.cloud.firestore_v1.document import DocumentReference + + client = mock.Mock(spec=["get_all"]) + transaction = self._make_one(client) + ref = DocumentReference("documents", "doc-id") + result = transaction.get(ref) + client.get_all.assert_called_once_with([ref], transaction=transaction.id) + self.assertIs(result, client.get_all.return_value) + + def test_get_w_query(self): + from google.cloud.firestore_v1.query import Query + + client = mock.Mock(spec=[]) + transaction = self._make_one(client) + query = Query(parent=mock.Mock(spec=[])) + query.stream = mock.MagicMock() + result = transaction.get(query) + query.stream.assert_called_once_with(transaction=transaction.id) + self.assertIs(result, query.stream.return_value) + + def test_get_failure(self): + client = _make_client() + transaction = self._make_one(client) + ref_or_query = object() + with self.assertRaises(ValueError): + transaction.get(ref_or_query) + class Test_Transactional(unittest.TestCase): @staticmethod