Skip to content

Commit

Permalink
Reject invalid paths passed to 'Query.{select,where,order_by}' (#6770)
Browse files Browse the repository at this point in the history
Closes #6736.
  • Loading branch information
tseaver authored Nov 30, 2018
1 parent b41293d commit 92a907a
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 72 deletions.
112 changes: 69 additions & 43 deletions firestore/google/cloud/firestore_v1beta1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ def decode_dict(value_fields, client):
}


# Pattern for a "simple" field name: a letter or underscore followed by any
# number of letters, digits, or underscores.  Compiled once at module level.
SIMPLE_FIELD_NAME = re.compile('^[_a-zA-Z][_a-zA-Z0-9]*$')


def get_field_path(field_names):
"""Create a **field path** from a list of nested field names.
Expand All @@ -468,11 +471,10 @@ def get_field_path(field_names):
Returns:
str: The ``.``-delimited field path.
"""
simple_field_name = re.compile('^[_a-zA-Z][_a-zA-Z0-9]*$')
result = []

for field_name in field_names:
match = re.match(simple_field_name, field_name)
match = SIMPLE_FIELD_NAME.match(field_name)
if match and match.group(0) == field_name:
result.append(field_name)
else:
Expand All @@ -482,6 +484,70 @@ def get_field_path(field_names):
return FIELD_PATH_DELIMITER.join(result)


PATH_ELEMENT_TOKENS = [
('SIMPLE', r'[_a-zA-Z][_a-zA-Z0-9]*'), # unquoted elements
('QUOTED', r'`(?:\\`|[^`])*?`'), # quoted elements, unquoted
('DOT', r'\.'), # separator
]
TOKENS_PATTERN = '|'.join(
'(?P<{}>{})'.format(*pair) for pair in PATH_ELEMENT_TOKENS)
TOKENS_REGEX = re.compile(TOKENS_PATTERN)


def _tokenize_field_path(path):
"""Lex a field path into tokens (including dots).
Args:
path (str): field path to be lexed.
Returns:
List(str): tokens
"""
pos = 0
get_token = TOKENS_REGEX.match
match = get_token(path)
while match is not None:
type_ = match.lastgroup
value = match.group(type_)
yield value
pos = match.end()
match = get_token(path, pos)


def split_field_path(path):
    """Break a field path into its elements, dropping the dot separators.

    Args:
        path (str): field path to be split.

    Returns:
        List[str]: the path elements, in order.  An empty ``path`` gives an
        empty list.

    Raises:
        ValueError: if the tokens of ``path`` do not strictly alternate
            element, dot, element, ... starting and ending with an element.
    """
    if not path:
        return []

    parts = []
    expecting_element = True

    for token in _tokenize_field_path(path):
        is_dot = token == '.'
        # A dot where an element is due, or an element where a dot is due,
        # both break the alternation.
        if is_dot == expecting_element:
            raise ValueError("Invalid path: {}".format(path))
        if expecting_element:
            parts.append(token)
        expecting_element = not expecting_element

    # Still waiting for an element means the path ended on a dot or
    # produced no tokens at all.
    if expecting_element or not parts:
        raise ValueError("Invalid path: {}".format(path))

    return parts


def parse_field_path(api_repr):
"""Parse a **field path** from into a list of nested field names.
Expand All @@ -501,8 +567,7 @@ def parse_field_path(api_repr):
# code dredged back up from
# https://github.com/googleapis/google-cloud-python/pull/5109/files
field_names = []
while api_repr:
field_name, api_repr = _parse_field_name(api_repr)
for field_name in split_field_path(api_repr):
# non-simple field name
if field_name[0] == '`' and field_name[-1] == '`':
field_name = field_name[1:-1]
Expand All @@ -512,45 +577,6 @@ def parse_field_path(api_repr):
return field_names


def _parse_field_name(api_repr):
"""
Parses the api_repr into the first field name and the rest
Args:
api_repr (str): The unique Firestore api representation.
Returns:
Tuple[str, str]:
A tuple with the first field name and the api_repr
of the rest.
"""
# XXX code dredged back up from
# https://github.com/googleapis/google-cloud-python/pull/5109/files;
# probably needs some speeding up

if '.' not in api_repr:
return api_repr, None

if api_repr[0] != '`': # first field name is simple
index = api_repr.index('.')
return api_repr[:index], api_repr[index+1:] # skips delimiter

# starts with backtick: find next non-escaped backtick.
index = 1
while index < len(api_repr):

if api_repr[index] == '`': # end of quoted field name
break

if api_repr[index] == '\\': # escape character, skip next
index += 2
else:
index += 1

if index == len(api_repr): # no closing backtick found
raise ValueError("No closing backtick: {}".format(api_repr))

return api_repr[:index+1], api_repr[index+2:]


def get_nested_value(field_path, data):
"""Get a (potentially nested) value from a dictionary.
Expand Down
12 changes: 12 additions & 0 deletions firestore/google/cloud/firestore_v1beta1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,13 @@ def select(self, field_paths):
~.firestore_v1beta1.query.Query: A "projected" query. Acts as
a copy of the current query, modified with the newly added
projection.
Raises:
ValueError: If any ``field_path`` is invalid.
"""
field_paths = list(field_paths)
for field_path in field_paths:
_helpers.split_field_path(field_path) # raises

new_projection = query_pb2.StructuredQuery.Projection(
fields=[
query_pb2.StructuredQuery.FieldReference(field_path=field_path)
Expand Down Expand Up @@ -204,9 +210,12 @@ def where(self, field_path, op_string, value):
copy of the current query, modified with the newly added filter.
Raises:
ValueError: If ``field_path`` is invalid.
ValueError: If ``value`` is a NaN or :data:`None` and
``op_string`` is not ``==``.
"""
_helpers.split_field_path(field_path) # raises

if value is None:
if op_string != _EQ_OP:
raise ValueError(_BAD_OP_NAN_NULL)
Expand Down Expand Up @@ -269,9 +278,12 @@ def order_by(self, field_path, direction=ASCENDING):
"order by" constraint.
Raises:
ValueError: If ``field_path`` is invalid.
ValueError: If ``direction`` is not one of :attr:`ASCENDING` or
:attr:`DESCENDING`.
"""
_helpers.split_field_path(field_path) # raises

order_pb = query_pb2.StructuredQuery.Order(
field=query_pb2.StructuredQuery.FieldReference(
field_path=field_path,
Expand Down
109 changes: 80 additions & 29 deletions firestore/tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,86 @@ def test_multiple(self):
self.assertEqual(self._call_fut(['a', 'b', 'c']), 'a.b.c')


class Test__tokenize_field_path(unittest.TestCase):
    """Unit tests for ``_helpers._tokenize_field_path``."""

    @staticmethod
    def _call_fut(path):
        """Invoke the function under test on ``path``."""
        from google.cloud.firestore_v1beta1 import _helpers

        tokenize = _helpers._tokenize_field_path
        return tokenize(path)

    def _expect(self, path, expected):
        """Assert that lexing ``path`` produces exactly ``expected``."""
        # The result is materialised with ``list`` before comparison.
        tokens = list(self._call_fut(path))
        self.assertEqual(tokens, expected)

    def test_w_empty(self):
        # An empty path has no tokens.
        self._expect('', [])

    def test_w_single_dot(self):
        # A lone separator is itself a token.
        self._expect('.', ['.'])

    def test_w_single_simple(self):
        self._expect('abc', ['abc'])

    def test_w_single_quoted(self):
        # Quoted elements keep their backticks.
        self._expect('`c*de`', ['`c*de`'])

    def test_w_quoted_embedded_dot(self):
        # A dot inside backticks is not a separator.
        self._expect('`c*.de`', ['`c*.de`'])

    def test_w_quoted_escaped_backtick(self):
        # An escaped backtick does not end the quoted element.
        self._expect(r'`c*\`de`', [r'`c*\`de`'])

    def test_w_dotted_quoted(self):
        self._expect('`*`.`~`', ['`*`', '.', '`~`'])

    def test_w_dotted(self):
        self._expect('a.b.`c*de`', ['a', '.', 'b', '.', '`c*de`'])


class Test_split_field_path(unittest.TestCase):
    """Unit tests for ``_helpers.split_field_path``."""

    @staticmethod
    def _call_fut(path):
        """Invoke the function under test on ``path``."""
        from google.cloud.firestore_v1beta1 import _helpers

        split = _helpers.split_field_path
        return split(path)

    def _assert_invalid(self, path):
        """Assert that splitting ``path`` raises ``ValueError``."""
        self.assertRaises(ValueError, self._call_fut, path)

    def test_w_single_dot(self):
        self._assert_invalid('.')

    def test_w_leading_dot(self):
        self._assert_invalid('.a.b.c')

    def test_w_trailing_dot(self):
        self._assert_invalid('a.b.')

    def test_w_missing_dot(self):
        # Adjacent elements with no separator between them.
        self._assert_invalid('a`c*de`f')

    def test_w_half_quoted_field(self):
        # Opening backtick without a closing one.
        self._assert_invalid('`c*de')

    def test_w_empty(self):
        elements = self._call_fut('')
        self.assertEqual(elements, [])

    def test_w_simple_field(self):
        elements = self._call_fut('a')
        self.assertEqual(elements, ['a'])

    def test_w_dotted_field(self):
        elements = self._call_fut('a.b.cde')
        self.assertEqual(elements, ['a', 'b', 'cde'])

    def test_w_quoted_field(self):
        # Quoted elements are returned with their backticks.
        elements = self._call_fut('a.b.`c*de`')
        self.assertEqual(elements, ['a', 'b', '`c*de`'])

    def test_w_quoted_field_escaped_backtick(self):
        elements = self._call_fut(r'`c*\`de`')
        self.assertEqual(elements, [r'`c*\`de`'])


class Test_parse_field_path(unittest.TestCase):

@staticmethod
Expand All @@ -880,35 +960,6 @@ def test_w_escaped_backtick(self):
def test_w_escaped_backslash(self):
self.assertEqual(self._call_fut('`a\\\\b`.c.d'), ['a\\b', 'c', 'd'])


class Test__parse_field_name(unittest.TestCase):

@staticmethod
def _call_fut(field_path):
from google.cloud.firestore_v1beta1._helpers import _parse_field_name

return _parse_field_name(field_path)

def test_w_no_dots(self):
name, rest = self._call_fut('a')
self.assertEqual(name, 'a')
self.assertIsNone(rest)

def test_w_first_name_simple(self):
name, rest = self._call_fut('a.b.c')
self.assertEqual(name, 'a')
self.assertEqual(rest, 'b.c')

def test_w_first_name_escaped_no_escapse(self):
name, rest = self._call_fut('`3`.b.c')
self.assertEqual(name, '`3`')
self.assertEqual(rest, 'b.c')

def test_w_first_name_escaped_w_escaped_backtick(self):
name, rest = self._call_fut('`a\\`b`.c.d')
self.assertEqual(name, '`a\\`b`')
self.assertEqual(rest, 'c.d')

def test_w_first_name_escaped_wo_closing_backtick(self):
with self.assertRaises(ValueError):
self._call_fut('`a\\`b.c.d')
Expand Down
18 changes: 18 additions & 0 deletions firestore/tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def _make_projection_for_select(field_paths):
],
)

def test_select_invalid_path(self):
    # ``select`` is expected to reject a field path that is neither a
    # simple nor a quoted element.
    bad_paths = ['*']
    query = self._make_one(mock.sentinel.parent)

    self.assertRaises(ValueError, query.select, bad_paths)

def test_select(self):
query1 = self._make_one_all_fields()

Expand All @@ -123,6 +129,12 @@ def test_select(self):
self._make_projection_for_select(field_paths3))
self._compare_queries(query2, query3, '_projection')

def test_where_invalid_path(self):
    # ``where`` is expected to reject a field path that is neither a
    # simple nor a quoted element.
    bad_path = '*'
    query = self._make_one(mock.sentinel.parent)

    self.assertRaises(ValueError, query.where, bad_path, '==', 1)

def test_where(self):
from google.cloud.firestore_v1beta1.gapic import enums
from google.cloud.firestore_v1beta1.proto import document_pb2
Expand Down Expand Up @@ -187,6 +199,12 @@ def test_where_le_nan(self):
with self.assertRaises(ValueError):
self._where_unary_helper(float('nan'), 0, op_string='<=')

def test_order_by_invalid_path(self):
    # ``order_by`` is expected to reject a field path that is neither a
    # simple nor a quoted element.
    bad_path = '*'
    query = self._make_one(mock.sentinel.parent)

    self.assertRaises(ValueError, query.order_by, bad_path)

def test_order_by(self):
from google.cloud.firestore_v1beta1.gapic import enums

Expand Down

0 comments on commit 92a907a

Please sign in to comment.