Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Firestore: reject invalid paths passed to 'Query.{select,where,order_by}' #6770

Merged
merged 2 commits into from
Nov 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 69 additions & 43 deletions firestore/google/cloud/firestore_v1beta1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ def decode_dict(value_fields, client):
}


SIMPLE_FIELD_NAME = re.compile('^[_a-zA-Z][_a-zA-Z0-9]*$')


def get_field_path(field_names):
"""Create a **field path** from a list of nested field names.

Expand All @@ -468,11 +471,10 @@ def get_field_path(field_names):
Returns:
str: The ``.``-delimited field path.
"""
simple_field_name = re.compile('^[_a-zA-Z][_a-zA-Z0-9]*$')
result = []

for field_name in field_names:
match = re.match(simple_field_name, field_name)
match = SIMPLE_FIELD_NAME.match(field_name)
if match and match.group(0) == field_name:
result.append(field_name)
else:
Expand All @@ -482,6 +484,70 @@ def get_field_path(field_names):
return FIELD_PATH_DELIMITER.join(result)


PATH_ELEMENT_TOKENS = [
('SIMPLE', r'[_a-zA-Z][_a-zA-Z0-9]*'), # unquoted elements

This comment was marked as spam.

This comment was marked as spam.

('QUOTED', r'`(?:\\`|[^`])*?`'), # quoted elements, unquoted
('DOT', r'\.'), # separator
]
TOKENS_PATTERN = '|'.join(
'(?P<{}>{})'.format(*pair) for pair in PATH_ELEMENT_TOKENS)
TOKENS_REGEX = re.compile(TOKENS_PATTERN)


def _tokenize_field_path(path):
"""Lex a field path into tokens (including dots).

Args:
path (str): field path to be lexed.
Returns:
List(str): tokens
"""
pos = 0
get_token = TOKENS_REGEX.match
match = get_token(path)
while match is not None:
type_ = match.lastgroup
value = match.group(type_)
yield value
pos = match.end()
match = get_token(path, pos)


def split_field_path(path):
    """Split a field path into valid elements (without dots).

    Elements are returned exactly as lexed: a quoted element keeps its
    enclosing backticks.

    Args:
        path (str): field path to be lexed.
    Returns:
        List(str): the path elements, in order; an empty list when
            ``path`` is empty.
    Raises:
        ValueError: if the path does not match the elements-interspersed-
            with-dots pattern, or contains text that cannot be lexed as
            an element or a dot.
    """
    if not path:
        return []

    elements = []
    want_dot = False
    consumed = 0  # number of characters of ``path`` covered by tokens

    for element in _tokenize_field_path(path):
        consumed += len(element)
        if want_dot:
            if element != '.':
                raise ValueError("Invalid path: {}".format(path))
            want_dot = False
        else:
            if element == '.':
                raise ValueError("Invalid path: {}".format(path))
            elements.append(element)
            want_dot = True

    # Tokens are lexed contiguously from the start of ``path``, so a short
    # count means the lexer stopped at text it could not recognize; accepting
    # that would silently truncate the path (e.g. 'a*' -> ['a']).
    if consumed != len(path):
        raise ValueError("Invalid path: {}".format(path))

    # ``want_dot`` is only true right after an element, so a false value
    # here means the path was empty of elements or ended on a dot.
    if not want_dot:
        raise ValueError("Invalid path: {}".format(path))

    return elements


def parse_field_path(api_repr):
"""Parse a **field path** from into a list of nested field names.

Expand All @@ -501,8 +567,7 @@ def parse_field_path(api_repr):
# code dredged back up from
# https://github.com/googleapis/google-cloud-python/pull/5109/files
field_names = []
while api_repr:
field_name, api_repr = _parse_field_name(api_repr)
for field_name in split_field_path(api_repr):
# non-simple field name
if field_name[0] == '`' and field_name[-1] == '`':
field_name = field_name[1:-1]
Expand All @@ -512,45 +577,6 @@ def parse_field_path(api_repr):
return field_names


def _parse_field_name(api_repr):
"""
Parses the api_repr into the first field name and the rest
Args:
api_repr (str): The unique Firestore api representation.
Returns:
Tuple[str, str]:
A tuple with the first field name and the api_repr
of the rest.
"""
# XXX code dredged back up from
# https://github.com/googleapis/google-cloud-python/pull/5109/files;
# probably needs some speeding up

if '.' not in api_repr:
return api_repr, None

if api_repr[0] != '`': # first field name is simple
index = api_repr.index('.')
return api_repr[:index], api_repr[index+1:] # skips delimiter

# starts with backtick: find next non-escaped backtick.
index = 1
while index < len(api_repr):

if api_repr[index] == '`': # end of quoted field name
break

if api_repr[index] == '\\': # escape character, skip next
index += 2
else:
index += 1

if index == len(api_repr): # no closing backtick found
raise ValueError("No closing backtick: {}".format(api_repr))

return api_repr[:index+1], api_repr[index+2:]


def get_nested_value(field_path, data):
"""Get a (potentially nested) value from a dictionary.

Expand Down
12 changes: 12 additions & 0 deletions firestore/google/cloud/firestore_v1beta1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,13 @@ def select(self, field_paths):
~.firestore_v1beta1.query.Query: A "projected" query. Acts as
a copy of the current query, modified with the newly added
projection.
Raises:
ValueError: If any ``field_path`` is invalid.
"""
field_paths = list(field_paths)
for field_path in field_paths:
_helpers.split_field_path(field_path) # raises

new_projection = query_pb2.StructuredQuery.Projection(
fields=[
query_pb2.StructuredQuery.FieldReference(field_path=field_path)
Expand Down Expand Up @@ -204,9 +210,12 @@ def where(self, field_path, op_string, value):
copy of the current query, modified with the newly added filter.

Raises:
ValueError: If ``field_path`` is invalid.
ValueError: If ``value`` is a NaN or :data:`None` and
``op_string`` is not ``==``.
"""
_helpers.split_field_path(field_path) # raises

if value is None:
if op_string != _EQ_OP:
raise ValueError(_BAD_OP_NAN_NULL)
Expand Down Expand Up @@ -269,9 +278,12 @@ def order_by(self, field_path, direction=ASCENDING):
"order by" constraint.

Raises:
ValueError: If ``field_path`` is invalid.
ValueError: If ``direction`` is not one of :attr:`ASCENDING` or
:attr:`DESCENDING`.
"""
_helpers.split_field_path(field_path) # raises

order_pb = query_pb2.StructuredQuery.Order(
field=query_pb2.StructuredQuery.FieldReference(
field_path=field_path,
Expand Down
109 changes: 80 additions & 29 deletions firestore/tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,86 @@ def test_multiple(self):
self.assertEqual(self._call_fut(['a', 'b', 'c']), 'a.b.c')


class Test__tokenize_field_path(unittest.TestCase):
    # Exercises ``_helpers._tokenize_field_path``: each case lexes a path
    # and compares the full token sequence (dots included).

    @staticmethod
    def _call_fut(path):
        from google.cloud.firestore_v1beta1 import _helpers

        function_under_test = _helpers._tokenize_field_path
        return function_under_test(path)

    def _expect(self, path, split_path):
        # Materialize the tokens so they can be compared with a list.
        tokens = list(self._call_fut(path))
        self.assertEqual(tokens, split_path)

    def test_w_empty(self):
        self._expect(path='', split_path=[])

    def test_w_single_dot(self):
        self._expect(path='.', split_path=['.'])

    def test_w_single_simple(self):
        self._expect(path='abc', split_path=['abc'])

    def test_w_single_quoted(self):
        self._expect(path='`c*de`', split_path=['`c*de`'])

    def test_w_quoted_embedded_dot(self):
        # A dot inside backticks belongs to the element, not a separator.
        self._expect(path='`c*.de`', split_path=['`c*.de`'])

    def test_w_quoted_escaped_backtick(self):
        # An escaped backtick does not terminate the quoted element.
        self._expect(path=r'`c*\`de`', split_path=[r'`c*\`de`'])

    def test_w_dotted_quoted(self):
        self._expect(path='`*`.`~`', split_path=['`*`', '.', '`~`'])

    def test_w_dotted(self):
        self._expect(
            path='a.b.`c*de`',
            split_path=['a', '.', 'b', '.', '`c*de`'],
        )


class Test_split_field_path(unittest.TestCase):
    # Exercises ``_helpers.split_field_path``: malformed paths must raise
    # ``ValueError``; well-formed paths yield their elements without dots.

    @staticmethod
    def _call_fut(path):
        from google.cloud.firestore_v1beta1 import _helpers

        function_under_test = _helpers.split_field_path
        return function_under_test(path)

    # -- invalid paths -----------------------------------------------------

    def test_w_single_dot(self):
        self.assertRaises(ValueError, self._call_fut, '.')

    def test_w_leading_dot(self):
        self.assertRaises(ValueError, self._call_fut, '.a.b.c')

    def test_w_trailing_dot(self):
        self.assertRaises(ValueError, self._call_fut, 'a.b.')

    def test_w_missing_dot(self):
        self.assertRaises(ValueError, self._call_fut, 'a`c*de`f')

    def test_w_half_quoted_field(self):
        self.assertRaises(ValueError, self._call_fut, '`c*de')

    # -- valid paths -------------------------------------------------------

    def test_w_empty(self):
        elements = self._call_fut('')
        self.assertEqual(elements, [])

    def test_w_simple_field(self):
        elements = self._call_fut('a')
        self.assertEqual(elements, ['a'])

    def test_w_dotted_field(self):
        elements = self._call_fut('a.b.cde')
        self.assertEqual(elements, ['a', 'b', 'cde'])

    def test_w_quoted_field(self):
        # Quoted elements come back with their backticks intact.
        elements = self._call_fut('a.b.`c*de`')
        self.assertEqual(elements, ['a', 'b', '`c*de`'])

    def test_w_quoted_field_escaped_backtick(self):
        elements = self._call_fut(r'`c*\`de`')
        self.assertEqual(elements, [r'`c*\`de`'])


class Test_parse_field_path(unittest.TestCase):

@staticmethod
Expand All @@ -880,35 +960,6 @@ def test_w_escaped_backtick(self):
def test_w_escaped_backslash(self):
self.assertEqual(self._call_fut('`a\\\\b`.c.d'), ['a\\b', 'c', 'd'])


class Test__parse_field_name(unittest.TestCase):

@staticmethod
def _call_fut(field_path):
from google.cloud.firestore_v1beta1._helpers import _parse_field_name

return _parse_field_name(field_path)

def test_w_no_dots(self):
name, rest = self._call_fut('a')
self.assertEqual(name, 'a')
self.assertIsNone(rest)

def test_w_first_name_simple(self):
name, rest = self._call_fut('a.b.c')
self.assertEqual(name, 'a')
self.assertEqual(rest, 'b.c')

def test_w_first_name_escaped_no_escapse(self):
name, rest = self._call_fut('`3`.b.c')
self.assertEqual(name, '`3`')
self.assertEqual(rest, 'b.c')

def test_w_first_name_escaped_w_escaped_backtick(self):
name, rest = self._call_fut('`a\\`b`.c.d')
self.assertEqual(name, '`a\\`b`')
self.assertEqual(rest, 'c.d')

def test_w_first_name_escaped_wo_closing_backtick(self):
with self.assertRaises(ValueError):
self._call_fut('`a\\`b.c.d')
Expand Down
18 changes: 18 additions & 0 deletions firestore/tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def _make_projection_for_select(field_paths):
],
)

def test_select_invalid_path(self):
    # '*' is not a valid field path, so ``select`` must reject it.
    parent = mock.sentinel.parent
    query = self._make_one(parent)

    self.assertRaises(ValueError, query.select, ['*'])

def test_select(self):
query1 = self._make_one_all_fields()

Expand All @@ -123,6 +129,12 @@ def test_select(self):
self._make_projection_for_select(field_paths3))
self._compare_queries(query2, query3, '_projection')

def test_where_invalid_path(self):
    # '*' is not a valid field path, so ``where`` must reject it.
    parent = mock.sentinel.parent
    query = self._make_one(parent)

    self.assertRaises(ValueError, query.where, '*', '==', 1)

def test_where(self):
from google.cloud.firestore_v1beta1.gapic import enums
from google.cloud.firestore_v1beta1.proto import document_pb2
Expand Down Expand Up @@ -187,6 +199,12 @@ def test_where_le_nan(self):
with self.assertRaises(ValueError):
self._where_unary_helper(float('nan'), 0, op_string='<=')

def test_order_by_invalid_path(self):
    # '*' is not a valid field path, so ``order_by`` must reject it.
    parent = mock.sentinel.parent
    query = self._make_one(parent)

    self.assertRaises(ValueError, query.order_by, '*')

def test_order_by(self):
from google.cloud.firestore_v1beta1.gapic import enums

Expand Down