diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml index 61c8d41..9d723be 100644 --- a/.github/workflows/typecheck.yml +++ b/.github/workflows/typecheck.yml @@ -31,8 +31,7 @@ jobs: - name: Install dependencies run: uv sync --all-extras - - name: Run mypy + - name: Run ty run: | cd litecli - uv run --no-sync --frozen -- python -m ensurepip - uv run --no-sync --frozen -- python -m mypy --no-pretty --install-types --non-interactive . + uv run ty check -v diff --git a/AGENTS.md b/AGENTS.md index 9b3d6b5..8009d19 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -20,12 +20,11 @@ - Lint: `ruff check` (add `--fix` to auto-fix) - Format: `ruff format` -### Mypy (type checking) -- Repo-wide (recommended): `mypy --explicit-package-bases .` -- Per-package: `mypy --explicit-package-bases litecli` +## ty (type checking) +- Repo-wide `ty check -v` +- Per-package: `ty check litecli -v` - Notes: - Config is in `pyproject.toml` (target Python 3.9, stricter settings). - - Use `--explicit-package-bases` to avoid module discovery issues when running outside tox. ## Coding Style & Naming Conventions - Formatter/linter: Ruff (configured via `.pre-commit-config.yaml` and `tox`). diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b57b59..b854df8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## 1.18.0 + +### Internal + +- Switch mypy to ty for type checking. [(#242)](https://github.com/dbcli/litecli/pull/242/files) + ## 1.17.0 - 2025-09-28 ### Features diff --git a/litecli/clistyle.py b/litecli/clistyle.py index 1bef2f1..b364872 100644 --- a/litecli/clistyle.py +++ b/litecli/clistyle.py @@ -1,20 +1,20 @@ from __future__ import annotations import logging - +from typing import cast import pygments.styles -from pygments.token import string_to_tokentype, Token -from pygments.style import Style as PygmentsStyle -from pygments.util import ClassNotFound +from prompt_toolkit.styles import Style, merge_styles from prompt_toolkit.styles.pygments import style_from_pygments_cls -from prompt_toolkit.styles import merge_styles, Style from prompt_toolkit.styles.style import _MergedStyle +from pygments.style import Style as PygmentsStyle +from pygments.token import Token, _TokenType, string_to_tokentype +from pygments.util import ClassNotFound logger = logging.getLogger(__name__) # map Pygments tokens (ptk 1.0) to class names (ptk 2.0). -TOKEN_TO_PROMPT_STYLE: dict[Token, str] = { +TOKEN_TO_PROMPT_STYLE: dict[_TokenType, str] = { Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", Token.Menu.Completions.Completion: "completion-menu.completion", Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", @@ -43,10 +43,10 @@ } # reverse dict for cli_helpers, because they still expect Pygments tokens. -PROMPT_STYLE_TO_TOKEN: dict[str, Token] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} +PROMPT_STYLE_TO_TOKEN: dict[str, _TokenType] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} -def parse_pygments_style(token_name: str, style_object: PygmentsStyle | dict, style_dict: dict[str, str]) -> tuple[Token, str]: +def parse_pygments_style(token_name: str, style_object: PygmentsStyle | dict, style_dict: dict[str, str]) -> tuple[_TokenType, str]: """Parse token type and style string. :param token_name: str name of Pygments token. Example: "Token.String" @@ -111,4 +111,5 @@ class OutputStyle(PygmentsStyle): default_style = "" styles = style - return OutputStyle + # mypy does not complain but ty complains: error[invalid-return-type]: Return type does not match returned value. Hence added cast. + return cast(OutputStyle, PygmentsStyle) diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py index 465ab34..4e76faa 100644 --- a/litecli/completion_refresher.py +++ b/litecli/completion_refresher.py @@ -1,11 +1,10 @@ from __future__ import annotations import threading -from typing import Callable - -from .packages.special.main import COMMANDS from collections import OrderedDict +from typing import Callable, cast +from .packages.special.main import COMMANDS from .sqlcompleter import SQLCompleter from .sqlexecute import SQLExecute @@ -77,7 +76,9 @@ def _bg_refresh( # If callbacks is a single function then push it into a list. if callable(callbacks): - callbacks = [callbacks] + callbacks_list: list[Callable] = [callbacks] + else: + callbacks_list = list(cast(list[Callable], callbacks)) while 1: for refresher in self.refreshers.values(): @@ -94,7 +95,7 @@ def _bg_refresh( # break statement. continue - for callback in callbacks: + for callback in callbacks_list: callback(completer) diff --git a/litecli/config.py b/litecli/config.py index d93bb37..953bc16 100644 --- a/litecli/config.py +++ b/litecli/config.py @@ -1,11 +1,10 @@ from __future__ import annotations import errno -import shutil import os import platform -from os.path import expanduser, exists, dirname - +import shutil +from os.path import dirname, exists, expanduser from configobj import ConfigObj @@ -55,7 +54,7 @@ def upgrade_config(config: str, def_config: str) -> None: def get_config(liteclirc_file: str | None = None) -> ConfigObj: from litecli import __file__ as package_root - package_root = os.path.dirname(package_root) + package_root = os.path.dirname(str(package_root)) liteclirc_file = liteclirc_file or f"{config_location()}config" diff --git a/litecli/main.py b/litecli/main.py index c5796f7..b05577c 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -13,17 +13,17 @@ from io import open try: - from sqlean import OperationalError, sqlite_version + from sqlean import OperationalError, sqlite_version # type: ignore[import-untyped] except ImportError: from sqlite3 import OperationalError, sqlite_version from time import time -from typing import Any, Iterable +from typing import Any, Generator, Iterable, cast import click import sqlparse from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors from prompt_toolkit.auto_suggest import AutoSuggestFromHistory -from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.completion import Completion, DynamicCompleter from prompt_toolkit.document import Document from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.filters import HasFocus, IsDone @@ -35,8 +35,6 @@ ) from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.shortcuts import CompleteStyle, PromptSession -from typing import cast -from prompt_toolkit.completion import Completion from .__init__ import __version__ from .clibuffer import cli_is_multiline @@ -53,8 +51,6 @@ from .sqlcompleter import SQLCompleter from .sqlexecute import SQLExecute -click.disable_unicode_literals_warning = True - # Query tuples are used for maintaining history Query = namedtuple("Query", ["query", "successful", "mutating"]) @@ -84,7 +80,8 @@ def __init__( self.key_bindings = c["main"]["key_bindings"] special.set_favorite_queries(self.config) self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) - self.formatter.litecli = self + # self.formatter.litecli = self, ty raises unresolved-attribute, hence use dynamic assignment + setattr(self.formatter, "litecli", self) self.syntax_style = c["main"]["syntax_style"] self.less_chatty = c["main"].as_bool("less_chatty") self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") @@ -181,7 +178,7 @@ def register_special_commands(self) -> None: case_sensitive=True, ) - def change_table_format(self, arg: str, **_: Any) -> Iterable[tuple]: + def change_table_format(self, arg: str, **_: Any) -> Generator[tuple[None, None, None, str], None, None]: try: self.formatter.format_name = arg yield (None, None, None, "Changed table format to {}".format(arg)) @@ -200,11 +197,14 @@ def change_db(self, arg: str | None, **_: Any) -> Iterable[tuple]: self.sqlexecute.connect(database=arg) self.refresh_completions() + # guard so that ty doesn't complain + dbname = self.sqlexecute.dbname if self.sqlexecute is not None else "" + yield ( None, None, None, - 'You are now connected to database "%s"' % (self.sqlexecute.dbname), + 'You are now connected to database "%s"' % (dbname), ) def execute_from_file(self, arg: str | None, **_: Any) -> Iterable[tuple[Any, ...]]: @@ -303,7 +303,7 @@ def get(key: str) -> str | None: return {x: get(x) for x in keys} - def connect(self, database: str = "") -> None: + def connect(self, database: str | None = "") -> None: cnf: dict[str, str | None] = {"database": None} cnf = self.read_my_cnf_files(cnf.keys()) @@ -510,7 +510,8 @@ def one_iteration(text: str | None = None) -> None: successful = False start = time() res = sqlexecute.run(text) - self.formatter.query = text + # Set query attribute dynamically on formatter + setattr(self.formatter, "query", text) successful = True special.unset_once_if_written() # Keep track of whether or not the query is mutating. In case @@ -522,7 +523,8 @@ def one_iteration(text: str | None = None) -> None: raise e except KeyboardInterrupt: try: - sqlexecute.conn.interrupt() + # since connection can be sqlite3 or sqlean, it's hard to annotate the type for interrupt. so ignore the type hint warning. + sqlexecute.conn.interrupt() # type: ignore[attr-defined] except Exception as e: self.echo( "Encountered error while cancelling query: {}".format(e), @@ -755,6 +757,7 @@ def refresh_completions(self, reset: bool = False) -> list[tuple]: if reset: with self._completer_lock: self.completer.reset_completions() + assert self.sqlexecute is not None self.completion_refresher.refresh( self.sqlexecute, self._on_completions_refreshed, @@ -815,7 +818,7 @@ def run_query(self, query: str, new_line: bool = True) -> None: results = self.sqlexecute.run(query) for result in results: title, cur, headers, status = result - self.formatter.query = query + setattr(self.formatter, "query", query) output = self.format_output(title, cur, headers) for line in output: click.echo(line, nl=new_line) diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py index 86a16f6..1a5cd6d 100644 --- a/litecli/packages/parseutils.py +++ b/litecli/packages/parseutils.py @@ -4,8 +4,8 @@ from typing import Generator, Iterable, Literal import sqlparse -from sqlparse.sql import IdentifierList, Identifier, Function, Token, TokenList -from sqlparse.tokens import Keyword, DML, Punctuation +from sqlparse.sql import Function, Identifier, IdentifierList, Token, TokenList +from sqlparse.tokens import DML, Keyword, Punctuation cleanup_regex: dict[str, re.Pattern[str]] = { # This matches only alphanumerics and underscores. @@ -18,10 +18,10 @@ "all_punctuations": re.compile(r"([^\s]+)$"), } +LAST_WORD_INCLUDE_TYPE = Literal["alphanum_underscore", "many_punctuations", "most_punctuations", "all_punctuations"] -def last_word( - text: str, include: Literal["alphanum_underscore", "many_punctuations", "most_punctuations", "all_punctuations"] = "alphanum_underscore" -) -> str: + +def last_word(text: str, include: LAST_WORD_INCLUDE_TYPE = "alphanum_underscore") -> str: R""" Find the last word in a sentence. diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py index d50137e..410f25e 100644 --- a/litecli/packages/special/__init__.py +++ b/litecli/packages/special/__init__.py @@ -1,6 +1,7 @@ # ruff: noqa from __future__ import annotations +from types import FunctionType from typing import Callable, Any @@ -9,11 +10,34 @@ def export(defn: Callable[..., Any]) -> Callable[..., Any]: """Decorator to explicitly mark functions that are exposed in a lib.""" - globals()[defn.__name__] = defn - __all__.append(defn.__name__) + # ty, requires explict check for callable of tyep | function type to access __name__ + if isinstance(defn, (type, FunctionType)): + globals()[defn.__name__] = defn + __all__.append(defn.__name__) return defn from . import dbcommands from . import iocommands from . import llm +from . import utils +from .main import CommandNotFound, register_special_command, execute +from .iocommands import ( + set_favorite_queries, + editor_command, + get_filename, + get_editor_query, + open_external_editor, + is_expanded_output, + set_expanded_output, + write_tee, + unset_once_if_written, + unset_pipe_once_if_written, + disable_pager, + set_pager, + is_pager_enabled, + write_once, + write_pipe_once, + close_tee, +) +from .llm import is_llm_command, handle_llm, FinishIteration diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py index 71bc348..3dd2a89 100644 --- a/litecli/packages/special/favoritequeries.py +++ b/litecli/packages/special/favoritequeries.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import builtins from typing import Any, cast @@ -39,7 +40,7 @@ class FavoriteQueries(object): def __init__(self, config: Any) -> None: self.config = config - def list(self) -> list[str]: + def list(self) -> builtins.list[str]: section = cast(dict[str, str], self.config.get(self.section_name, {})) return list(section.keys()) diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index be6067e..434a96f 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -195,6 +195,7 @@ def execute_favorite_query(cur: Any, arg: str, verbose: bool = False, **_: Any) if arg_error: yield (None, None, None, arg_error) else: + assert query, "query should be non-empty" for sql in sqlparse.split(query): sql = sql.rstrip(";") title = "> %s" % (sql) if verbose else None diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 1b167ff..e2d7efa 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -25,7 +25,8 @@ LLM_TEMPLATE_NAME = "litecli-llm-template" LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys()) # Mapping of model_id to None used for completion tree leaves. -MODELS: dict[str, None] = {x.model_id: None for x in llm.get_models()} +# the file name is llm.py and module name is llm, hence ty is complaining that get_models is missing. +MODELS: dict[str, None] = {x.model_id: None for x in llm.get_models()} # type: ignore[attr-defined] def run_external_cmd( diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py index f320dec..0263447 100644 --- a/litecli/sqlcompleter.py +++ b/litecli/sqlcompleter.py @@ -1,18 +1,18 @@ from __future__ import annotations import logging -from re import compile, escape from collections import Counter +from re import compile, escape from typing import Any, Collection, Generator, Iterable, Literal, Sequence from prompt_toolkit.completion import CompleteEvent, Completer, Completion from prompt_toolkit.completion.base import Document from .packages.completion_engine import suggest_type -from .packages.parseutils import last_word -from .packages.special.iocommands import favoritequeries +from .packages.filepaths import complete_path, parse_path, suggest_path +from .packages.parseutils import LAST_WORD_INCLUDE_TYPE, last_word from .packages.special import llm -from .packages.filepaths import parse_path, complete_path, suggest_path +from .packages.special.iocommands import favoritequeries _logger = logging.getLogger(__name__) @@ -381,7 +381,7 @@ def extend_functions(self, func_data: Iterable[Sequence[str]]) -> None: metadata[self.dbname][func[0]] = None self.all_completions.add(func[0]) - def set_dbname(self, dbname: str) -> None: + def set_dbname(self, dbname: str | None) -> None: self.dbname = dbname def reset_completions(self) -> None: @@ -397,7 +397,7 @@ def find_matches( start_only: bool = False, fuzzy: bool = True, casing: str | None = None, - punctuations: str = "most_punctuations", + punctuations: LAST_WORD_INCLUDE_TYPE = "most_punctuations", ) -> Generator[Completion, None, None]: """Find completion matches for the given text. diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py index 889915f..14f15a0 100644 --- a/litecli/sqlexecute.py +++ b/litecli/sqlexecute.py @@ -1,25 +1,24 @@ from __future__ import annotations import logging -from typing import Any, Generator, Iterable - from contextlib import closing +from typing import Any, Generator, Iterable try: - import sqlean as sqlite3 - from sqlean import OperationalError + import sqlean as sqlite3 # type: ignore[import-untyped] + from sqlean import OperationalError # type: ignore[import-untyped] sqlite3.extensions.enable_all() except ImportError: import sqlite3 from sqlite3 import OperationalError -from litecli.packages.special.utils import check_if_sqlitedotcommand - -import sqlparse import os.path from urllib.parse import urlparse -from .packages import special +import sqlparse + +from litecli.packages import special +from litecli.packages.special.utils import check_if_sqlitedotcommand _logger = logging.getLogger(__name__) @@ -88,7 +87,8 @@ def connect(self, database: str | None = None) -> None: if not os.path.exists(db_dir_name): raise Exception("Path does not exist: {}".format(db_dir_name)) - conn = sqlite3.connect(database=db_name, isolation_level=None, uri=uri) + # sqlean exposes the connect method during run-time + conn = sqlite3.connect(database=db_name, isolation_level=None, uri=uri) # type: ignore[attr-defined] conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace") if self.conn: self.conn.close() diff --git a/pyproject.toml b/pyproject.toml index 04ab2b0..1b619b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,13 +11,12 @@ dependencies = [ "cli-helpers[styles]>=2.2.1", "click>=4.1,!=8.1.*", "configobj>=5.0.5", - "mypy>=1.17.1", "prompt-toolkit>=3.0.3,<4.0.0", "pygments>=1.6", "sqlparse>=0.4.4", - "setuptools", # Required by llm commands to install models + "setuptools", # Required by llm commands to install models "pip", - "llm>=0.25.0", + "llm>=0.25.0" ] [build-system] @@ -46,6 +45,7 @@ dev = [ "tox>=4.8.0", "pdbpp>=0.10.3", "llm>=0.19.0", + "ty>=0.0.4", ] [tool.setuptools.packages.find] @@ -57,33 +57,18 @@ litecli = ["liteclirc", "AUTHORS"] [tool.ruff] line-length = 140 -[tool.mypy] -pretty = true -strict_equality = true -ignore_missing_imports = true -warn_unreachable = true -warn_redundant_casts = true -warn_no_return = true -warn_unused_configs = true -show_column_numbers = true -show_error_codes = true -warn_unused_ignores = true -python_version = "3.9" -# Resolve module discovery reliably -explicit_package_bases = true -packages = ["litecli"] -# Gradually tighten typing -disallow_incomplete_defs = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true +[tool.ty.environment] +python-version = "3.9" +root = [".", "litecli", "litecli/packages", "litecli/packages/special"] + + +[tool.ty.src] exclude = [ - '^build/', - '^dist/', - '^\.tox/', - '^\.venv/', - '^venv/', - '^\.mypy_cache/', - '^\.pytest_cache/', - '^\.ruff_cache/', + '**/build/', + '**/dist/', + '**/.tox/', + '**/.venv/', + '**/.mypy_cache/', + '**/.pytest_cache/', + '**/.ruff_cache/' ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/conftest.py b/tests/conftest.py index a29201d..f7bf346 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,13 @@ -# mypy: ignore-errors - from __future__ import print_function import os import pytest -from utils import create_db, db_connection, drop_tables import litecli.sqlexecute +from .utils import create_db, db_connection, drop_tables + @pytest.fixture(scope="function") def connection(): diff --git a/tests/test_clistyle.py b/tests/test_clistyle.py index 95c48e0..00d6f5a 100644 --- a/tests/test_clistyle.py +++ b/tests/test_clistyle.py @@ -1,9 +1,8 @@ # -*- coding: utf-8 -*- -# mypy: ignore-errors + """Test the litecli.clistyle module.""" import pytest - from pygments.style import Style from pygments.token import Token @@ -17,7 +16,7 @@ def test_style_factory(): cli_style = {"Token.Output.Header": header} style = style_factory("default", cli_style) - assert isinstance(style(), Style) + assert isinstance(style, Style) assert Token.Output.Header in style.styles assert header == style.styles[Token.Output.Header] @@ -27,4 +26,4 @@ def test_style_factory_unknown_name(): """Test that an unrecognized name will not throw an error.""" style = style_factory("foobar", {}) - assert isinstance(style(), Style) + assert isinstance(style, Style) diff --git a/tests/test_completion_engine.py b/tests/test_completion_engine.py index 35ced44..f2a8a76 100644 --- a/tests/test_completion_engine.py +++ b/tests/test_completion_engine.py @@ -1,7 +1,6 @@ -# mypy: ignore-errors +import pytest from litecli.packages.completion_engine import suggest_type -import pytest def sorted_dicts(dicts): diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py index 6e3a7c5..32da6bd 100644 --- a/tests/test_completion_refresher.py +++ b/tests/test_completion_refresher.py @@ -1,9 +1,8 @@ -# mypy: ignore-errors - import time -import pytest from unittest.mock import Mock, patch +import pytest + @pytest.fixture def refresher(): diff --git a/tests/test_dbspecial.py b/tests/test_dbspecial.py index b2e731c..65e33d0 100644 --- a/tests/test_dbspecial.py +++ b/tests/test_dbspecial.py @@ -1,21 +1,19 @@ -# mypy: ignore-errors - -from test_completion_engine import sorted_dicts -from utils import assert_result_equal, dbtest, run - from litecli.packages.completion_engine import suggest_type from litecli.packages.special.utils import check_if_sqlitedotcommand, format_uptime +from .test_completion_engine import sorted_dicts +from .utils import assert_result_equal, dbtest, run + def test_import_first_argument(): - test_cases = [ + test_cases: list[tuple[str, int]] = [ # text, expecting_arg_idx - [".import ", 1], - [".import ./da", 1], - [".import ./data.csv ", 2], - [".import ./data.csv t", 2], - [".import ./data.csv `t", 2], - ['.import ./data.csv "t', 2], + (".import ", 1), + (".import ./da", 1), + (".import ./data.csv ", 2), + (".import ./data.csv t", 2), + (".import ./data.csv `t", 2), + ('.import ./data.csv "t', 2), ] for text, expecting_arg_idx in test_cases: suggestions = suggest_type(text, text) @@ -53,20 +51,14 @@ def test_list_or_show_create_tables(): def test_format_uptime(): - seconds = 59 - assert "59 sec" == format_uptime(seconds) - - seconds = 120 - assert "2 min 0 sec" == format_uptime(seconds) - - seconds = 54890 - assert "15 hours 14 min 50 sec" == format_uptime(seconds) - - seconds = 598244 - assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds) - - seconds = 522600 - assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds) + for seconds, human_readable_text in [ + ("59", "59 sec"), + ("120", "2 min 0 sec"), + ("54890", "15 hours 14 min 50 sec"), + ("598244", "6 days 22 hours 10 min 44 sec"), + ("522600", "6 days 1 hour 10 min 0 sec"), + ]: + assert human_readable_text == format_uptime(seconds) def test_indexes(): diff --git a/tests/test_llm_special.py b/tests/test_llm_special.py index a7477f7..d7de461 100644 --- a/tests/test_llm_special.py +++ b/tests/test_llm_special.py @@ -1,8 +1,8 @@ -# mypy: ignore-errors +from unittest.mock import patch import pytest -from unittest.mock import patch -from litecli.packages.special.llm import handle_llm, FinishIteration, USAGE + +from litecli.packages.special.llm import USAGE, FinishIteration, handle_llm @patch("litecli.packages.special.llm.llm") diff --git a/tests/test_main.py b/tests/test_main.py index 787233c..4c880e7 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,19 +1,19 @@ -# mypy: ignore-errors - import os -from collections import namedtuple -from textwrap import dedent import shutil +from collections import namedtuple from datetime import datetime +from textwrap import dedent from unittest.mock import patch import click import pytest from click.testing import CliRunner +from prompt_toolkit import PromptSession -from litecli.main import cli, LiteCli +from litecli.main import LiteCli, cli from litecli.packages.special.main import COMMANDS as SPECIAL_COMMANDS -from utils import dbtest, run, create_db, db_connection + +from .utils import create_db, db_connection, dbtest, run test_dir = os.path.abspath(os.path.dirname(__file__)) project_dir = os.path.dirname(test_dir) @@ -21,6 +21,8 @@ CLI_ARGS = ["--liteclirc", default_config_file, "_test_db"] +clickoutput: str + @dbtest def test_execute_arg(executor): @@ -135,6 +137,7 @@ def test_help_strings_end_with_periods(): for param in cli.params: if isinstance(param, click.core.Option): assert hasattr(param, "help") + assert isinstance(param.help, str) assert param.help.endswith(".") @@ -158,7 +161,7 @@ class TestExecute: def server_type(self): return ["test"] - class PromptBuffer: + class PromptBuffer(PromptSession): output = TestOutput() m.prompt_app = PromptBuffer() @@ -237,10 +240,10 @@ def stub_terminal_size(): old_func = shutil.get_terminal_size - shutil.get_terminal_size = stub_terminal_size + shutil.get_terminal_size = stub_terminal_size # type: ignore[assignment] lc = LiteCli() assert isinstance(lc.get_reserved_space(), int) - shutil.get_terminal_size = old_func + shutil.get_terminal_size = old_func # type: ignore[assignment] @dbtest @@ -266,6 +269,7 @@ def test_import_command(executor): def test_startup_commands(executor): m = LiteCli(liteclirc=default_config_file) + assert m.startup_commands assert m.startup_commands["commands"] == [ "create table startupcommands(a text)", "insert into startupcommands values('abc')", @@ -326,7 +330,8 @@ def test_get_prompt(mock_datetime): assert lc.get_prompt(r"\s") == "42" # 11. Test when dbname is None => (none) - lc.connect(None) # Simulate no DB connection + lc.connect(None) + # Simulate no DB connection and incorrect argument type assert lc.get_prompt(r"\d") == "(none)" assert lc.get_prompt(r"\f") == "(none)" diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py index 52cd3b5..fdb3317 100644 --- a/tests/test_parseutils.py +++ b/tests/test_parseutils.py @@ -1,11 +1,10 @@ -# mypy: ignore-errors - import pytest + from litecli.packages.parseutils import ( extract_tables, - query_starts_with, - queries_start_with, is_destructive, + queries_start_with, + query_starts_with, ) diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py index 51c4361..2de74ce 100644 --- a/tests/test_prompt_utils.py +++ b/tests/test_prompt_utils.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# mypy: ignore-errors + import click diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index f4d32ce..30d19a6 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -1,8 +1,10 @@ # coding: utf-8 -# mypy: ignore-errors + from __future__ import unicode_literals -import pytest + from unittest.mock import patch + +import pytest from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document diff --git a/tests/test_special_iocommands.py b/tests/test_special_iocommands.py index b4f1e5b..e661b48 100644 --- a/tests/test_special_iocommands.py +++ b/tests/test_special_iocommands.py @@ -1,12 +1,10 @@ -# mypy: ignore-errors - import os import tempfile import pytest import litecli.packages.special -from litecli.packages.special.main import parse_special_command, Verbosity +from litecli.packages.special.main import Verbosity, parse_special_command def test_once_command(): diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py index 57693cf..e1e5539 100644 --- a/tests/test_sqlexecute.py +++ b/tests/test_sqlexecute.py @@ -1,14 +1,13 @@ # coding=UTF-8 -# mypy: ignore-errors import os import pytest -from utils import run, dbtest, set_expanded_output, is_expanded_output, assert_result_equal +from .utils import assert_result_equal, dbtest, is_expanded_output, run, set_expanded_output try: - from sqlean import OperationalError, ProgrammingError + from sqlean import OperationalError, ProgrammingError # type: ignore[import-untyped] except ImportError: from sqlite3 import OperationalError, ProgrammingError diff --git a/tests/utils.py b/tests/utils.py index a097f64..4bb7b48 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- -# mypy: ignore-errors + +import multiprocessing import os -import time import signal -import platform -import multiprocessing +import sys +import time from contextlib import closing + import pytest try: @@ -20,7 +21,7 @@ def db_connection(dbname=":memory:"): - conn = sqlite3.connect(database=dbname, isolation_level=None) + conn = sqlite3.connect(database=dbname, isolation_level=None) # type: ignore[attr-defined] return conn @@ -75,8 +76,8 @@ def send_ctrl_c_to_pid(pid, wait_seconds): """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds` seconds.""" time.sleep(wait_seconds) - system_name = platform.system() - if system_name == "Windows": + # ty, is aware of sys.platform and not platform.system. See: https://github.com/astral-sh/ty/issues/2033 + if sys.platform == "win32": os.kill(pid, signal.CTRL_C_EVENT) else: os.kill(pid, signal.SIGINT)