diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 9f4ed98..fa6b965 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,7 +10,7 @@ jobs: uses: actions/checkout@v2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 - name: Set up pdoc run: pip install pdoc3 diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index 448b30b..45f9807 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -32,7 +32,7 @@ jobs: fetch-depth: 0 - name: Set up Python 3.8 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: 3.8 diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml index 0737d8a..a2e1608 100644 --- a/.github/workflows/publish-to-test-pypi.yml +++ b/.github/workflows/publish-to-test-pypi.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/checkout@v3 - name: Set up Python 3.8 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: 3.8 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6bdec31..a173482 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} cache: 'pip' diff --git a/src/amplitude_experiment/remote/client.py b/src/amplitude_experiment/remote/client.py index 152a98d..52baece 100644 --- a/src/amplitude_experiment/remote/client.py +++ b/src/amplitude_experiment/remote/client.py @@ -5,6 +5,7 @@ from typing import Any, Dict from .config import RemoteEvaluationConfig +from .fetch_options import FetchOptions from ..connection_pool import HTTPConnectionPool from ..exception import FetchException from ..user import User @@ -34,7 +35,7 @@ def __init__(self, api_key, config=None): self.logger = self.config.logger self.__setup_connection_pool() - def fetch_v2(self, user: User): + def fetch_v2(self, user: User, fetch_options: FetchOptions = None): """ Fetch all variants for a user synchronously. This method will automatically retry if configured, and throw if all retries fail. This function differs from fetch as it will return a default variant object if the flag @@ -42,12 +43,13 @@ def fetch_v2(self, user: User): Parameters: user (User): The Experiment User to fetch variants for. + fetch_options (FetchOptions): The Fetch Options Returns: Variants Dictionary. """ try: - return self.__fetch_internal(user) + return self.__fetch_internal(user, fetch_options) except Exception as e: self.logger.error(f"[Experiment] Failed to fetch variants: {e}") raise e @@ -63,17 +65,18 @@ def fetch_async_v2(self, user: User, callback=None): thread.start() @deprecated("Use fetch_v2") - def fetch(self, user: User): + def fetch(self, user: User, fetch_options: FetchOptions = None): """ Fetch all variants for a user synchronous. This method will automatically retry if configured. Parameters: user (User): The Experiment User + fetch_options (FetchOptions): The Fetch Options Returns: Variants Dictionary. """ try: - variants = self.fetch_v2(user) + variants = self.fetch_v2(user, fetch_options) return self.__filter_default_variants(variants) except Exception: return {} @@ -103,16 +106,16 @@ def __fetch_async_internal(self, user, callback): callback(user, {}, e) return {} - def __fetch_internal(self, user): + def __fetch_internal(self, user, fetch_options: FetchOptions = None): self.logger.debug(f"[Experiment] Fetching variants for user: {user}") try: - return self.__do_fetch(user) + return self.__do_fetch(user, fetch_options) except Exception as e: self.logger.error(f"[Experiment] Fetch failed: {e}") if self.__should_retry_fetch(e): - return self.__retry_fetch(user) + return self.__retry_fetch(user, fetch_options) - def __retry_fetch(self, user): + def __retry_fetch(self, user, fetch_options: FetchOptions = None): if self.config.fetch_retries == 0: return {} self.logger.debug("[Experiment] Retrying fetch") @@ -121,7 +124,7 @@ def __retry_fetch(self, user): for i in range(self.config.fetch_retries): sleep(delay_millis / 1000.0) try: - return self.__do_fetch(user) + return self.__do_fetch(user, fetch_options) except Exception as e: self.logger.error(f"[Experiment] Retry failed: {e}") err = e @@ -129,13 +132,18 @@ def __retry_fetch(self, user): self.config.fetch_retry_backoff_max_millis) raise err - def __do_fetch(self, user): + def __do_fetch(self, user, fetch_options: FetchOptions = None): start = time.time() user_context = self.__add_context(user) headers = { 'Authorization': f"Api-Key {self.api_key}", 'Content-Type': 'application/json;charset=utf-8' } + if fetch_options and fetch_options.tracksAssignment is not None: + headers['X-Amp-Exp-Track'] = "track" if fetch_options.tracksAssignment else "no-track" + if fetch_options and fetch_options.tracksExposure is not None: + headers['X-Amp-Exp-Exposure-Track'] = "track" if fetch_options.tracksExposure else "no-track" + conn = self._connection_pool.acquire() body = user_context.to_json().encode('utf8') if len(body) > 8000: diff --git a/src/amplitude_experiment/remote/fetch_options.py b/src/amplitude_experiment/remote/fetch_options.py new file mode 100644 index 0000000..8538888 --- /dev/null +++ b/src/amplitude_experiment/remote/fetch_options.py @@ -0,0 +1,14 @@ +from typing import Optional +class FetchOptions: + def __init__(self, tracksAssignment: Optional[bool] = None, tracksExposure: Optional[bool] = None): + """ + Fetch Options + Parameters: + tracksAssignment (Optional[bool]): Whether to track the assignment. The default None uses the server's default behavior (track the assignment event). + tracksExposure (Optional[bool]): Whether to track the exposure. The default None uses the server's default behavior (don't track the exposure event). + """ + self.tracksAssignment = tracksAssignment + self.tracksExposure = tracksExposure + + def __str__(self): + return f"FetchOptions(tracksAssignment={self.tracksAssignment}, tracksExposure={self.tracksExposure})" diff --git a/tests/remote/client_test.py b/tests/remote/client_test.py index a051249..43fcf24 100644 --- a/tests/remote/client_test.py +++ b/tests/remote/client_test.py @@ -1,3 +1,4 @@ +import json import unittest from unittest import mock @@ -5,6 +6,7 @@ from src.amplitude_experiment import RemoteEvaluationClient, Variant, User, RemoteEvaluationConfig from src.amplitude_experiment.exception import FetchException +from src.amplitude_experiment.remote.fetch_options import FetchOptions API_KEY = 'client-DvWljIjiiuqLbyjqdvBaLFfEBrAvGuA3' SERVER_URL = 'https://api.lab.amplitude.com/sdk/vardata' @@ -46,6 +48,39 @@ def test_fetch_failed_with_retry(self): variants = client.fetch(user) self.assertEqual({}, variants) + def test_fetch_with_fetch_options(self): + with RemoteEvaluationClient(API_KEY) as client: + user = User(user_id='test_user') + + mock_conn = mock.MagicMock() + client._connection_pool.acquire = lambda: mock_conn + mock_conn.request.return_value = mock.MagicMock(status=200) + mock_conn.request.return_value.read.return_value = json.dumps({ + 'sdk-ci-test': { + 'key': 'on' + } + }).encode('utf8') + + variants = client.fetch_v2(user, FetchOptions(tracksAssignment=False, tracksExposure=True)) + self.assertIn('sdk-ci-test', variants) + mock_conn.request.assert_called_once_with('POST', '/sdk/v2/vardata?v=0', mock.ANY, { + 'Authorization': f"Api-Key {API_KEY}", + 'Content-Type': 'application/json;charset=utf-8', + 'X-Amp-Exp-Track': 'no-track', + 'X-Amp-Exp-Exposure-Track': 'track' + }) + + mock_conn.request.reset_mock() + + variants = client.fetch_v2(user, FetchOptions(tracksAssignment=True, tracksExposure=False)) + self.assertIn('sdk-ci-test', variants) + mock_conn.request.assert_called_once_with('POST', '/sdk/v2/vardata?v=0', mock.ANY, { + 'Authorization': f"Api-Key {API_KEY}", + 'Content-Type': 'application/json;charset=utf-8', + 'X-Amp-Exp-Track': 'track', + 'X-Amp-Exp-Exposure-Track': 'no-track' + }) + @parameterized.expand([ (300, "Fetch Exception 300", True), (400, "Fetch Exception 400", False), @@ -63,8 +98,8 @@ def test_fetch_retry_with_response(self, response_code, error_message, should_ca mock_do_fetch.side_effect = FetchException(response_code, error_message) instance = RemoteEvaluationClient(API_KEY, RemoteEvaluationConfig(fetch_retries=1)) user = User(user_id='test_user') - instance.fetch(user) - mock_do_fetch.assert_called_once_with(user) + instance.fetch_v2(user) + mock_do_fetch.assert_called_once_with(user, None) self.assertEqual(should_call_retry, mock_retry_fetch.called)