Skip to content

Commit 65b1a7c

Browse files
authored
Escape \ and " in typeahead query (#1306)
1 parent 4624829 commit 65b1a7c

File tree

3 files changed

+160
-4
lines changed

3 files changed

+160
-4
lines changed

src/marqo/core/typeahead/typeahead.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import blake3
66

7-
from marqo.core.constants import MARQO_TYPEAHEAD_SCHEMA_MINIMUM_VERSION
7+
from marqo.core.constants import MARQO_TYPEAHEAD_SCHEMA_MINIMUM_VERSION, CHARACTERS_TO_BE_ESCAPED_IN_VESPA
88
from marqo.core.index_management.index_management import IndexManagement
99
from marqo.core.models.typeahead import (
1010
TypeaheadRequest, TypeaheadResponse, TypeaheadSuggestion,
@@ -66,19 +66,20 @@ def get_suggestions(self, index_name: str, request: TypeaheadRequest) -> Typeahe
6666
retrieval_terms = []
6767
ranking_terms = []
6868
for token in tokens:
69+
escaped_token = self._escape_token(token)
6970
if len(token) < request.min_fuzzy_match_length:
7071
# Use exact prefix matching for short tokens
7172
retrieval_terms.append(
72-
f"query_words contains ({{prefix:true}}\"{token}\")"
73+
f"query_words contains ({{prefix:true}}\"{escaped_token}\")"
7374
)
7475
else:
7576
# Use fuzzy matching for longer tokens
7677
retrieval_terms.append(
7778
f"query_words contains "
78-
f"({{maxEditDistance:{request.fuzzy_edit_distance}, prefix:true}}fuzzy(\"{token}\"))"
79+
f"({{maxEditDistance:{request.fuzzy_edit_distance}, prefix:true}}fuzzy(\"{escaped_token}\"))"
7980
)
8081

81-
ranking_terms.append(f"query_index contains \"{token}\"")
82+
ranking_terms.append(f"query_index contains \"{escaped_token}\"")
8283

8384
# Create single YQL query that ORs all token conditions
8485
yql_retrieval = " OR ".join(retrieval_terms)
@@ -309,6 +310,23 @@ def get_queries(self, index_name: str, queries: List[str]) -> TypeaheadGetQuerie
309310

310311
return TypeaheadGetQueriesResponse(queries=query_results)
311312

313+
def _escape_token(self, token: str) -> str:
    """Return ``token`` with Vespa YQL special characters backslash-escaped.

    Every character of ``token`` that is a member of
    ``CHARACTERS_TO_BE_ESCAPED_IN_VESPA`` is emitted with one leading
    backslash; all other characters are copied through unchanged.

    Args:
        token: The token to escape.

    Returns:
        The escaped token. An empty token yields an empty string.
    """
    return ''.join(
        '\\' + ch if ch in CHARACTERS_TO_BE_ESCAPED_IN_VESPA else ch
        for ch in token
    )
329+
312330
def _generate_query_hash(self, query: str) -> str:
313331
"""Generate a 128-bit blake3 hash for a query string.
314332

tests/integ_tests/core/typeahead/test_typeahead_integration.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,47 @@ def test_typeahead_version_check_with_old_version_index(self):
456456

457457
self._assert_version_error_message(context.exception, self.index_220_name, "2.22.0")
458458

459+
# G. Special Character Handling Tests
460+
def test_typeahead_with_special_characters_in_user_input(self):
    """Test typeahead search handles special characters in user input without causing Vespa 500 errors.

    Each query containing double quotes and/or backslashes is sent through
    ``get_suggestions`` in its own subtest. The call itself must not raise,
    and the returned response must be non-None with a sized ``suggestions``
    attribute.
    """
    self._index_test_queries()

    # User input queries with special characters that should NOT cause Vespa 500 errors
    user_input_test_cases = [
        # Single special characters
        '"',
        '\\',
        # Queries with quotes
        'a"b"c',
        'a"b"',
        '"b"c',
        '"bc',
        'bc"',
        # Queries with backslashes
        'Path\\to\\file',
        '\\\\',
        'a\\b',
        '\\a',
        'b\\',
        # Mixed special characters
        'Program "with spaces"\\folder',
        '"\\',
        '\\"',
    ]

    for user_query in user_input_test_cases:
        with self.subTest(user_query=user_query):
            request = TypeaheadRequest(q=user_query, limit=10)
            # Only the call is guarded: an exception here is the regression
            # this test exists to catch. The assertions stay outside the
            # try block so their own failure messages are not rewritten
            # into a misleading "caused an error" report.
            try:
                response = self.config.typeahead.get_suggestions(self.test_index_name, request)
            except Exception as e:
                self.fail(f"User input query '{user_query}' caused an error: {e}")

            # Just verify we get a well-formed response
            self.assertIsNotNone(response)
            self.assertGreaterEqual(len(response.suggestions), 0)
499+
459500
def _assert_version_error_message(self, exception: UnsupportedFeatureError, index_name: str, version: str):
460501
"""Helper method to verify the error message contains expected information."""
461502
error_message = str(exception)

tests/unit_tests/marqo/core/typeahead/test_typeahead.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,5 +782,102 @@ def test_get_queries_mixed_found_not_found(self):
782782
self.assertEqual(result.queries[0].query, "found query")
783783

784784

785+
class TestTypeaheadEscaping(unittest.TestCase):
    """Unit tests for ``Typeahead._escape_token``."""

    def setUp(self):
        """Build a Typeahead instance backed by mocked collaborators."""
        vespa_client = Mock(spec=VespaClient)
        index_management = Mock(spec=IndexManagement)
        self.mock_vespa_client = vespa_client
        self.mock_index_management = index_management
        self.typeahead = Typeahead(
            vespa_client=vespa_client,
            index_management=index_management,
        )

    def _assert_unchanged(self, tokens):
        """Assert that escaping leaves each token exactly as it was."""
        for token in tokens:
            with self.subTest(token=token):
                self.assertEqual(self.typeahead._escape_token(token), token)

    def _assert_escaped(self, pairs):
        """Assert that each (input, expected) pair escapes as expected."""
        for input_token, expected in pairs:
            with self.subTest(input_token=input_token):
                self.assertEqual(self.typeahead._escape_token(input_token), expected)

    def test_escape_token_no_special_chars(self):
        """Tokens without quotes or backslashes come back unmodified."""
        self._assert_unchanged([
            "hello",
            "world",
            "123",
            "test_query",
            "simple-text",
            "with spaces",
            "",
        ])

    def test_escape_token_quotes(self):
        """Every double quote gains a leading backslash."""
        self._assert_escaped([
            ('"', '\\"'),
            ('hello"world', 'hello\\"world'),
            ('"start', '\\"start'),
            ('end"', 'end\\"'),
            ('""', '\\"\\"'),
            ('say "hello"', 'say \\"hello\\"'),
        ])

    def test_escape_token_backslashes(self):
        """Every backslash is doubled."""
        self._assert_escaped([
            ('\\', '\\\\'),
            ('hello\\world', 'hello\\\\world'),
            ('\\start', '\\\\start'),
            ('end\\', 'end\\\\'),
            ('\\\\', '\\\\\\\\'),
            ('path\\to\\file', 'path\\\\to\\\\file'),
        ])

    def test_escape_token_mixed_special_chars(self):
        """Quotes and backslashes in the same token are each escaped."""
        self._assert_escaped([
            ('"\\', '\\"\\\\'),
            ('"hello\\world"', '\\"hello\\\\world\\"'),
            ('C:\\Program Files\\"test"', 'C:\\\\Program Files\\\\\\"test\\"'),
            ('a"b\\c"d', 'a\\"b\\\\c\\"d'),
        ])

    def test_escape_token_preserves_other_chars(self):
        """Punctuation outside CHARACTERS_TO_BE_ESCAPED_IN_VESPA is left alone."""
        self._assert_unchanged([
            "hello!world",
            "price$100",
            "50%off",
            "a+b=c",
            "question?",
            "array[0]",
            "function()",
            "hash#tag",
        ])
881+
785882
if __name__ == "__main__":
786883
unittest.main()

0 commit comments

Comments
 (0)