From 8779d8dd9723cc08caafb47be64d5c7f3e0462e0 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 3 Jul 2019 15:00:39 +0530 Subject: [PATCH 1/4] added get and get_all method in transaction class --- .../google/cloud/firestore_v1/transaction.py | 33 ++++ firestore/tests/unit/v1/test_transaction.py | 158 +++++++++++++++++- 2 files changed, 190 insertions(+), 1 deletion(-) diff --git a/firestore/google/cloud/firestore_v1/transaction.py b/firestore/google/cloud/firestore_v1/transaction.py index 1e28cc9ac431..64fcaf1ad2d5 100644 --- a/firestore/google/cloud/firestore_v1/transaction.py +++ b/firestore/google/cloud/firestore_v1/transaction.py @@ -23,6 +23,8 @@ from google.api_core import exceptions from google.cloud.firestore_v1 import batch from google.cloud.firestore_v1 import types +from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.query import Query MAX_ATTEMPTS = 5 @@ -200,6 +202,37 @@ def _commit(self): self._clean_up() return list(commit_response.write_results) + def get_all(self, references): + """Retrieves multiple documents from Firestore. + + Args: + references (List[.DocumentReference, ...]): Iterable of document + references to be retrieved. + + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + return self._client.get_all(references, transaction=self._id) + + def get(self, ref_or_query): + """ + Retrieve a document or a query result from the database. + Args: + ref_or_query The document references or query object to return. + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + if isinstance(ref_or_query, DocumentReference): + return self._client.get_all([ref_or_query], transaction=self._id) + elif isinstance(ref_or_query, Query): + return ref_or_query.stream(self._id) + else: + raise ValueError( + 'Value for argument "refOrQuery" must be a DocumentReference or a Query.' + ) + class _Transactional(object): """Provide a callable object to use as a transactional decorater. diff --git a/firestore/tests/unit/v1/test_transaction.py b/firestore/tests/unit/v1/test_transaction.py index ed578ad3eea6..268f035345ac 100644 --- a/firestore/tests/unit/v1/test_transaction.py +++ b/firestore/tests/unit/v1/test_transaction.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest - +import datetime import mock @@ -329,6 +329,109 @@ def test__commit_failure(self): metadata=client._rpc_metadata, ) + def _get_all_helper(self, client, document_pbs): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["batch_get_documents"]) + response_iterator = iter(document_pbs) + firestore_api.batch_get_documents.return_value = response_iterator + + # Attach the fake GAPIC to a real client. + client._firestore_api_internal = firestore_api + + def test_get_all(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + import types + + client = _make_client() + trans = self._make_one(client) + data1 = {"a": u"cheese"} + data2 = {"b": True, "c": 18} + document1 = client.document("pineapple", "lamp1") + document2 = client.document("pineapple", "lamp2") + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + document_pb2, read_time = _doc_get_info(document2._document_path, data2) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) + response2 = _make_batch_response(found=document_pb2, read_time=read_time) + + self._get_all_helper(client, [response1, response2]) + # Actually call get_all(). + snapshots = trans.get_all([document1, document2]) + self.assertIsInstance(snapshots, types.GeneratorType) + snapshots = list(snapshots) + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document1) + self.assertEqual(snapshot1._data, data1) + + snapshot2 = snapshots[1] + self.assertIsInstance(snapshot2, DocumentSnapshot) + self.assertIs(snapshot2._reference, document2) + self.assertEqual(snapshot2._data, data2) + + def test_get_document_ref(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + import types + + client = _make_client() + trans = self._make_one(client) + data1 = {"a": u"cheese"} + document1 = client.document("pineapple", "lamp1") + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) + self._get_all_helper(client, [response1]) + snapshots = trans.get(document1) + self.assertIsInstance(snapshots, types.GeneratorType) + snapshots = list(snapshots) + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document1) + self.assertEqual(snapshot1._data, data1) + + def test_get_query_ref(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.query import Query + import types + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + trans = self._make_one(client) + client._firestore_api_internal = firestore_api + parent = client.collection("declaration") + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Pass the query to get method and check the response. + query = Query(parent) + snapshots = trans.get(query) + self.assertIsInstance(snapshots, types.GeneratorType) + returned = list(snapshots) + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertEqual(snapshot.reference._path, ("declaration", "burger")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=trans._id, + metadata=client._rpc_metadata, + ) + + def test_get_failure(self): + client = _make_client() + trans = self._make_one(client) + ref_or_query = object() + with self.assertRaises(ValueError): + trans.get(ref_or_query) + class Test_Transactional(unittest.TestCase): @staticmethod @@ -983,3 +1086,56 @@ def _make_transaction(txn_id, **txn_kwargs): client._firestore_api_internal = firestore_api return Transaction(client, **txn_kwargs) + + +def _make_batch_response(**kwargs): + from google.cloud.firestore_v1.proto import firestore_pb2 + + return firestore_pb2.BatchGetDocumentsResponse(**kwargs) + + +def _doc_get_info(ref_string, values): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + + document_pb = document_pb2.Document( + name=ref_string, + fields=_helpers.encode_dict(values), + create_time=create_time, + update_time=update_time, + ) + + return document_pb, read_time + + +def _make_query_response(**kwargs): + # kwargs supported are ``skipped_results``, ``name`` and ``data`` + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + kwargs["read_time"] = read_time + + name = kwargs.pop("name", None) + data = kwargs.pop("data", None) + + document_pb = document_pb2.Document(name=name, fields=_helpers.encode_dict(data)) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + document_pb.update_time.CopyFrom(update_time) + document_pb.create_time.CopyFrom(create_time) + + kwargs["document"] = document_pb + + return firestore_pb2.RunQueryResponse(**kwargs) From 9b3594350be2a7bf61bdd19da70a05507db43b6d Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 3 Jul 2019 17:56:37 +0530 Subject: [PATCH 2/4] remane variables as per recommended --- .../google/cloud/firestore_v1/transaction.py | 2 +- firestore/tests/unit/v1/test_transaction.py | 46 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/firestore/google/cloud/firestore_v1/transaction.py b/firestore/google/cloud/firestore_v1/transaction.py index 64fcaf1ad2d5..6cdce52c5f3b 100644 --- a/firestore/google/cloud/firestore_v1/transaction.py +++ b/firestore/google/cloud/firestore_v1/transaction.py @@ -230,7 +230,7 @@ def get(self, ref_or_query): return ref_or_query.stream(self._id) else: raise ValueError( - 'Value for argument "refOrQuery" must be a DocumentReference or a Query.' + 'Value for argument "ref_or_query" must be a DocumentReference or a Query.' ) diff --git a/firestore/tests/unit/v1/test_transaction.py b/firestore/tests/unit/v1/test_transaction.py index 268f035345ac..ed2bd4146123 100644 --- a/firestore/tests/unit/v1/test_transaction.py +++ b/firestore/tests/unit/v1/test_transaction.py @@ -343,7 +343,7 @@ def test_get_all(self): import types client = _make_client() - trans = self._make_one(client) + transaction = self._make_one(client) data1 = {"a": u"cheese"} data2 = {"b": True, "c": 18} document1 = client.document("pineapple", "lamp1") @@ -355,37 +355,37 @@ def test_get_all(self): self._get_all_helper(client, [response1, response2]) # Actually call get_all(). - snapshots = trans.get_all([document1, document2]) + snapshots = transaction.get_all([document1, document2]) self.assertIsInstance(snapshots, types.GeneratorType) snapshots = list(snapshots) - snapshot1 = snapshots[0] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document1) - self.assertEqual(snapshot1._data, data1) + snapshot0 = snapshots[0] + self.assertIsInstance(snapshot0, DocumentSnapshot) + self.assertIs(snapshot0._reference, document1) + self.assertEqual(snapshot0._data, data1) - snapshot2 = snapshots[1] - self.assertIsInstance(snapshot2, DocumentSnapshot) - self.assertIs(snapshot2._reference, document2) - self.assertEqual(snapshot2._data, data2) + snapshot1 = snapshots[1] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document2) + self.assertEqual(snapshot1._data, data2) def test_get_document_ref(self): from google.cloud.firestore_v1.document import DocumentSnapshot import types client = _make_client() - trans = self._make_one(client) + transaction = self._make_one(client) data1 = {"a": u"cheese"} document1 = client.document("pineapple", "lamp1") document_pb1, read_time = _doc_get_info(document1._document_path, data1) response1 = _make_batch_response(found=document_pb1, read_time=read_time) self._get_all_helper(client, [response1]) - snapshots = trans.get(document1) + snapshots = transaction.get(document1) self.assertIsInstance(snapshots, types.GeneratorType) snapshots = list(snapshots) - snapshot1 = snapshots[0] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document1) - self.assertEqual(snapshot1._data, data1) + snapshot0 = snapshots[0] + self.assertIsInstance(snapshot0, DocumentSnapshot) + self.assertIs(snapshot0._reference, document1) + self.assertEqual(snapshot0._data, data1) def test_get_query_ref(self): from google.cloud.firestore_v1.document import DocumentSnapshot @@ -397,7 +397,7 @@ def test_get_query_ref(self): # Attach the fake GAPIC to a real client. client = _make_client() - trans = self._make_one(client) + transaction = self._make_one(client) client._firestore_api_internal = firestore_api parent = client.collection("declaration") parent_path, expected_prefix = parent._parent_info() @@ -408,20 +408,20 @@ def test_get_query_ref(self): # Pass the query to get method and check the response. query = Query(parent) - snapshots = trans.get(query) + snapshots = transaction.get(query) self.assertIsInstance(snapshots, types.GeneratorType) returned = list(snapshots) self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertEqual(snapshot.reference._path, ("declaration", "burger")) - self.assertEqual(snapshot.to_dict(), data) + snapshot0 = returned[0] + self.assertIsInstance(snapshot0, DocumentSnapshot) + self.assertEqual(snapshot0.reference._path, ("declaration", "burger")) + self.assertEqual(snapshot0.to_dict(), data) # Verify the mock call. firestore_api.run_query.assert_called_once_with( parent_path, query._to_protobuf(), - transaction=trans._id, + transaction=transaction._id, metadata=client._rpc_metadata, ) From 5e6f7e1ef6f5ad81f123b53ff304f7be09ebe588 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 9 Jul 2019 19:27:58 +0530 Subject: [PATCH 3/4] rename variable name --- firestore/tests/unit/v1/test_transaction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firestore/tests/unit/v1/test_transaction.py b/firestore/tests/unit/v1/test_transaction.py index ed2bd4146123..3a3874dc2b35 100644 --- a/firestore/tests/unit/v1/test_transaction.py +++ b/firestore/tests/unit/v1/test_transaction.py @@ -427,10 +427,10 @@ def test_get_query_ref(self): def test_get_failure(self): client = _make_client() - trans = self._make_one(client) + transaction = self._make_one(client) ref_or_query = object() with self.assertRaises(ValueError): - trans.get(ref_or_query) + transaction.get(ref_or_query) class Test_Transactional(unittest.TestCase): From 93595309fba3a254e1f2ac2ecd19a2291c4810e0 Mon Sep 17 00:00:00 2001 From: Hemang Date: Fri, 12 Jul 2019 11:55:28 +0530 Subject: [PATCH 4/4] update test cases as per recomonded --- .../google/cloud/firestore_v1/transaction.py | 2 +- firestore/tests/unit/v1/test_transaction.py | 159 +++--------------- 2 files changed, 20 insertions(+), 141 deletions(-) diff --git a/firestore/google/cloud/firestore_v1/transaction.py b/firestore/google/cloud/firestore_v1/transaction.py index 6cdce52c5f3b..9d4068c75a88 100644 --- a/firestore/google/cloud/firestore_v1/transaction.py +++ b/firestore/google/cloud/firestore_v1/transaction.py @@ -227,7 +227,7 @@ def get(self, ref_or_query): if isinstance(ref_or_query, DocumentReference): return self._client.get_all([ref_or_query], transaction=self._id) elif isinstance(ref_or_query, Query): - return ref_or_query.stream(self._id) + return ref_or_query.stream(transaction=self._id) else: raise ValueError( 'Value for argument "ref_or_query" must be a DocumentReference or a Query.' diff --git a/firestore/tests/unit/v1/test_transaction.py b/firestore/tests/unit/v1/test_transaction.py index 3a3874dc2b35..8cae24a23831 100644 --- a/firestore/tests/unit/v1/test_transaction.py +++ b/firestore/tests/unit/v1/test_transaction.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest -import datetime import mock @@ -329,101 +328,34 @@ def test__commit_failure(self): metadata=client._rpc_metadata, ) - def _get_all_helper(self, client, document_pbs): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["batch_get_documents"]) - response_iterator = iter(document_pbs) - firestore_api.batch_get_documents.return_value = response_iterator - - # Attach the fake GAPIC to a real client. - client._firestore_api_internal = firestore_api - def test_get_all(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - import types - - client = _make_client() + client = mock.Mock(spec=["get_all"]) transaction = self._make_one(client) - data1 = {"a": u"cheese"} - data2 = {"b": True, "c": 18} - document1 = client.document("pineapple", "lamp1") - document2 = client.document("pineapple", "lamp2") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - document_pb2, read_time = _doc_get_info(document2._document_path, data2) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) - response2 = _make_batch_response(found=document_pb2, read_time=read_time) - - self._get_all_helper(client, [response1, response2]) - # Actually call get_all(). - snapshots = transaction.get_all([document1, document2]) - self.assertIsInstance(snapshots, types.GeneratorType) - snapshots = list(snapshots) - snapshot0 = snapshots[0] - self.assertIsInstance(snapshot0, DocumentSnapshot) - self.assertIs(snapshot0._reference, document1) - self.assertEqual(snapshot0._data, data1) - - snapshot1 = snapshots[1] - self.assertIsInstance(snapshot1, DocumentSnapshot) - self.assertIs(snapshot1._reference, document2) - self.assertEqual(snapshot1._data, data2) + ref1, ref2 = mock.Mock(), mock.Mock() + result = transaction.get_all([ref1, ref2]) + client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction.id) + self.assertIs(result, client.get_all.return_value) def test_get_document_ref(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - import types + from google.cloud.firestore_v1.document import DocumentReference - client = _make_client() + client = mock.Mock(spec=["get_all"]) transaction = self._make_one(client) - data1 = {"a": u"cheese"} - document1 = client.document("pineapple", "lamp1") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) - self._get_all_helper(client, [response1]) - snapshots = transaction.get(document1) - self.assertIsInstance(snapshots, types.GeneratorType) - snapshots = list(snapshots) - snapshot0 = snapshots[0] - self.assertIsInstance(snapshot0, DocumentSnapshot) - self.assertIs(snapshot0._reference, document1) - self.assertEqual(snapshot0._data, data1) - - def test_get_query_ref(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - from google.cloud.firestore_v1.query import Query - import types + ref = DocumentReference("documents", "doc-id") + result = transaction.get(ref) + client.get_all.assert_called_once_with([ref], transaction=transaction.id) + self.assertIs(result, client.get_all.return_value) - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) + def test_get_w_query(self): + from google.cloud.firestore_v1.query import Query - # Attach the fake GAPIC to a real client. - client = _make_client() + client = mock.Mock(spec=[]) transaction = self._make_one(client) - client._firestore_api_internal = firestore_api - parent = client.collection("declaration") - parent_path, expected_prefix = parent._parent_info() - name = "{}/burger".format(expected_prefix) - data = {"lettuce": b"\xee\x87"} - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = iter([response_pb]) - - # Pass the query to get method and check the response. - query = Query(parent) - snapshots = transaction.get(query) - self.assertIsInstance(snapshots, types.GeneratorType) - returned = list(snapshots) - self.assertEqual(len(returned), 1) - snapshot0 = returned[0] - self.assertIsInstance(snapshot0, DocumentSnapshot) - self.assertEqual(snapshot0.reference._path, ("declaration", "burger")) - self.assertEqual(snapshot0.to_dict(), data) - - # Verify the mock call. - firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=transaction._id, - metadata=client._rpc_metadata, - ) + query = Query(parent=mock.Mock(spec=[])) + query.stream = mock.MagicMock() + result = transaction.get(query) + query.stream.assert_called_once_with(transaction=transaction.id) + self.assertIs(result, query.stream.return_value) def test_get_failure(self): client = _make_client() @@ -1086,56 +1018,3 @@ def _make_transaction(txn_id, **txn_kwargs): client._firestore_api_internal = firestore_api return Transaction(client, **txn_kwargs) - - -def _make_batch_response(**kwargs): - from google.cloud.firestore_v1.proto import firestore_pb2 - - return firestore_pb2.BatchGetDocumentsResponse(**kwargs) - - -def _doc_get_info(ref_string, values): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1 import _helpers - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - - document_pb = document_pb2.Document( - name=ref_string, - fields=_helpers.encode_dict(values), - create_time=create_time, - update_time=update_time, - ) - - return document_pb, read_time - - -def _make_query_response(**kwargs): - # kwargs supported are ``skipped_results``, ``name`` and ``data`` - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1 import _helpers - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - kwargs["read_time"] = read_time - - name = kwargs.pop("name", None) - data = kwargs.pop("data", None) - - document_pb = document_pb2.Document(name=name, fields=_helpers.encode_dict(data)) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - document_pb.update_time.CopyFrom(update_time) - document_pb.create_time.CopyFrom(create_time) - - kwargs["document"] = document_pb - - return firestore_pb2.RunQueryResponse(**kwargs)