Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Release Notes

.. Upcoming Version

* Fix warning when multiplying variables with pd.Series containing time-zone aware index
* Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()``
* Add simplify method to LinearExpression to combine duplicate terms
* Add convenience function to create LinearExpression from constant
Expand Down
70 changes: 67 additions & 3 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from functools import partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypeVar, overload
from warnings import warn

import numpy as np
import pandas as pd
import polars as pl
import xarray as xr
from numpy import arange, signedinteger
from xarray import DataArray, Dataset, apply_ufunc, broadcast
from xarray import align as xr_align
Expand Down Expand Up @@ -45,6 +46,48 @@
from linopy.variables import Variable


class CoordAlignWarning(UserWarning):
    """Warning category used when coordinates of combined objects are not aligned."""


class TimezoneAlignError(ValueError): ...


P = ParamSpec("P")
R = TypeVar("R")


class CatchDatetimeTypeError:
"""Context manager that catches datetime-related TypeErrors and re-raises as TimezoneAlignError."""

def __enter__(self) -> CatchDatetimeTypeError:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
if exc_type is TypeError and exc_val is not None:
if "Cannot interpret 'datetime" in str(exc_val):
raise TimezoneAlignError(
"Timezone information across datetime coordinates not aligned."
) from exc_val
return False


def catch_datetime_type_error_and_re_raise(func: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate ``func`` so datetime-related TypeErrors surface as TimezoneAlignError.

    The call to ``func`` runs inside a ``CatchDatetimeTypeError`` context; its
    arguments and return value are passed through untouched.
    """

    @wraps(func)
    def _translating_call(*args: P.args, **kwargs: P.kwargs) -> R:
        with CatchDatetimeTypeError():
            return func(*args, **kwargs)

    return _translating_call


def set_int_index(series: pd.Series) -> pd.Series:
"""
Convert string index to int index.
Expand Down Expand Up @@ -128,6 +171,21 @@ def get_from_iterable(lst: DimsLike | None, index: int) -> Any | None:
return lst[index] if 0 <= index < len(lst) else None


def try_to_convert_to_pd_datetime_index(
    coord: xr.DataArray | Sequence | pd.Index | Any,
) -> pd.DatetimeIndex | xr.DataArray | Sequence | pd.Index | Any:
    """
    Best-effort conversion of a coordinate to a ``pd.DatetimeIndex``.

    Parameters
    ----------
    coord : xr.DataArray, Sequence, pd.Index or Any
        Coordinate values to convert.

    Returns
    -------
    pd.DatetimeIndex or the unchanged input
        ``coord`` itself if it already is a ``pd.DatetimeIndex``; the index of
        a DataArray if that index is a ``pd.DatetimeIndex``; otherwise the
        result of ``pd.DatetimeIndex(coord)``. If the conversion is not
        possible, ``coord`` is returned unchanged.
    """
    if isinstance(coord, pd.DatetimeIndex):
        return coord
    try:
        if isinstance(coord, xr.DataArray):
            index = coord.to_index()
            # Explicit check instead of ``assert``: assertions are stripped
            # under ``python -O``, which would let a non-datetime index leak out.
            return index if isinstance(index, pd.DatetimeIndex) else coord
        return pd.DatetimeIndex(coord)
    except Exception:
        # Deliberately broad: this is a best-effort helper and any failure to
        # convert simply means "leave the coordinate as it is".
        return coord


def pandas_to_dataarray(
arr: pd.DataFrame | pd.Series,
coords: CoordsLike | None = None,
Expand Down Expand Up @@ -168,7 +226,10 @@ def pandas_to_dataarray(
shared_dims = set(pandas_coords.keys()) & set(coords.keys())
non_aligned = []
for dim in shared_dims:
pd_coord = pandas_coords[dim]
coord = coords[dim]
if isinstance(pd_coord, pd.DatetimeIndex):
coord = try_to_convert_to_pd_datetime_index(coord)
if not isinstance(coord, pd.Index):
coord = pd.Index(coord)
if not pandas_coords[dim].equals(coord):
Expand All @@ -178,7 +239,8 @@ def pandas_to_dataarray(
f"coords for dimension(s) {non_aligned} is not aligned with the pandas object. "
"Previously, the indexes of the pandas were ignored and overwritten in "
"these cases. Now, the pandas object's coordinates are taken considered"
" for alignment."
" for alignment.",
CoordAlignWarning,
)

return DataArray(arr, coords=None, dims=dims, **kwargs)
Expand Down Expand Up @@ -449,6 +511,7 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
return df


@catch_datetime_type_error_and_re_raise
def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
"""
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
Expand All @@ -458,14 +521,15 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
except ValueError:
warn(
"Coordinates across variables not equal. Perform outer join.",
UserWarning,
CoordAlignWarning,
)
arrs = xr_align(*dataarrays, join="outer")
if integer_dtype:
arrs = tuple([ds.fillna(-1).astype(int) for ds in arrs])
return Dataset({ds.name: ds for ds in arrs})


@catch_datetime_type_error_and_re_raise
def assign_multiindex_safe(ds: Dataset, **fields: Any) -> Dataset:
"""
Assign a field to a xarray Dataset while being safe against warnings about multiindex corruption.
Expand Down
3 changes: 3 additions & 0 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
LocIndexer,
as_dataarray,
assign_multiindex_safe,
catch_datetime_type_error_and_re_raise,
check_common_keys_values,
check_has_nulls,
check_has_nulls_polars,
Expand Down Expand Up @@ -505,6 +506,7 @@ def __neg__(self: GenericExpression) -> GenericExpression:
"""
return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const)

@catch_datetime_type_error_and_re_raise
def _multiply_by_linear_expression(
self, other: LinearExpression | ScalarLinearExpression
) -> QuadraticExpression:
Expand All @@ -526,6 +528,7 @@ def _multiply_by_linear_expression(
res = res + self.reset_const() * other.const
return res

@catch_datetime_type_error_and_re_raise
def _multiply_by_constant(
self: GenericExpression, other: ConstantLike
) -> GenericExpression:
Expand Down
2 changes: 2 additions & 0 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
LocIndexer,
as_dataarray,
assign_multiindex_safe,
catch_datetime_type_error_and_re_raise,
check_has_nulls,
check_has_nulls_polars,
filter_nulls_polars,
Expand Down Expand Up @@ -295,6 +296,7 @@ def loc(self) -> LocIndexer:
def to_pandas(self) -> pd.Series:
return self.labels.to_pandas()

@catch_datetime_type_error_and_re_raise
def to_linexpr(
self,
coefficient: ConstantLike = 1,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ dev = [
"types-requests",
"gurobipy",
"highspy",
"types-pytz"
]
solvers = [
"gurobipy",
Expand Down
69 changes: 67 additions & 2 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
@author: fabian
"""

from datetime import datetime

import numpy as np
import pandas as pd
import polars as pl
import pytest
import xarray as xr
from pytz import UTC
from test_linear_expression import m, u, x # noqa: F401
from xarray import DataArray
from xarray.testing.assertions import assert_equal

from linopy import LinearExpression, Model, Variable
from linopy.common import (
CoordAlignWarning,
align,
as_dataarray,
assign_multiindex_safe,
Expand Down Expand Up @@ -73,6 +77,67 @@ def test_as_dataarray_with_series_dims_priority() -> None:
assert list(da.coords[target_dim].values) == target_index


def test_as_datarray_with_tz_aware_series_index() -> None:
    """A tz-aware Series index is kept by ``as_dataarray`` for several kinds of coords."""
    tz_aware_time = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    integer_time = pd.Index(data=[0, 1, 2, 3], name="time")
    series = pd.Series(data=1.0, index=tz_aware_time)

    # xarray coordinates built from the very same tz-aware index
    matching_da = xr.DataArray(data=[0, 1, 2, 3], coords=[tz_aware_time])
    converted = as_dataarray(arr=series, coords=matching_da.coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())

    # xarray coordinates built from an integer index: a warning is expected
    mismatching_da = xr.DataArray(data=[0, 1, 2, 3], coords=[integer_time])
    with pytest.warns(CoordAlignWarning):
        converted = as_dataarray(arr=series, coords=mismatching_da.coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())

    # plain dict holding the tz-aware index
    dict_coords = {"time": tz_aware_time}
    converted = as_dataarray(arr=series, coords=dict_coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())

    # plain dict holding integer labels
    dict_coords = {"time": [0, 1, 2, 3]}
    converted = as_dataarray(arr=series, coords=dict_coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())


def test_as_datarray_with_tz_aware_dataframe_columns_index() -> None:
    """tz-aware DataFrame columns are kept by ``as_dataarray`` for several kinds of coords."""
    tz_aware_time = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    integer_time = pd.Index(data=[0, 1, 2, 3], name="time")

    row_index = pd.Index([0, 1, 2, 3], name="x")
    frame = pd.DataFrame(data=1.0, index=row_index, columns=tz_aware_time)

    # xarray coordinates built from the very same tz-aware index
    matching_da = xr.DataArray(data=[0, 1, 2, 3], coords=[tz_aware_time])
    converted = as_dataarray(arr=frame, coords=matching_da.coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())

    # xarray coordinates built from an integer index: a warning is expected
    mismatching_da = xr.DataArray(data=[0, 1, 2, 3], coords=[integer_time])
    with pytest.warns(CoordAlignWarning):
        converted = as_dataarray(arr=frame, coords=mismatching_da.coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())

    # plain dict holding the tz-aware index
    dict_coords = {"time": tz_aware_time}
    converted = as_dataarray(arr=frame, coords=dict_coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())

    # plain dict holding integer labels
    dict_coords = {"time": [0, 1, 2, 3]}
    converted = as_dataarray(arr=frame, coords=dict_coords)
    assert tz_aware_time.equals(converted.coords["time"].to_index())


def test_as_dataarray_with_series_dims_subset() -> None:
target_dim = "dim_0"
target_index = ["a", "b", "c"]
Expand All @@ -99,7 +164,7 @@ def test_as_dataarray_with_series_override_coords() -> None:
target_dim = "dim_0"
target_index = ["a", "b", "c"]
s = pd.Series([1, 2, 3], index=target_index)
with pytest.warns(UserWarning):
with pytest.warns(CoordAlignWarning):
da = as_dataarray(s, coords=[[1, 2, 3]])
assert isinstance(da, DataArray)
assert da.dims == (target_dim,)
Expand Down Expand Up @@ -218,7 +283,7 @@ def test_as_dataarray_dataframe_override_coords() -> None:
target_index = ["a", "b"]
target_columns = ["A", "B"]
df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns)
with pytest.warns(UserWarning):
with pytest.warns(CoordAlignWarning):
da = as_dataarray(df, coords=[[1, 2], [2, 3]])
assert isinstance(da, DataArray)
assert da.dims == target_dims
Expand Down
28 changes: 28 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@

from __future__ import annotations

from datetime import datetime

import numpy as np
import pandas as pd
import polars as pl
import pytest
import xarray as xr
from pytz import UTC
from xarray.testing import assert_equal

from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge
from linopy.common import TimezoneAlignError
from linopy.constants import HELPER_DIMS, TERM_DIM
from linopy.expressions import ScalarLinearExpression
from linopy.testing import assert_linequal, assert_quadequal
Expand Down Expand Up @@ -1230,6 +1234,30 @@ def test_cumsum(m: Model, multiple: float) -> None:
cumsum.nterm == 2


def test_timezone_alignment_failure() -> None:
    """Multiplying a UTC-indexed linear expression by a tz-naive Series raises TimezoneAlignError."""
    aware_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    naive_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=None,
        name="time",
    )
    naive_series = pd.Series(data=1.0, index=naive_index)

    model = Model()
    variable = model.add_variables(coords=[aware_index], name="var1")
    expr = variable * 1.0

    # The point of the test: a useful TimezoneAlignError is expected here,
    # not a "not implemented" error falsely claiming that these types cannot
    # be multiplied together.
    with pytest.raises(TimezoneAlignError):
        _ = expr * naive_series


def test_simplify_basic(x: Variable) -> None:
"""Test basic simplification with duplicate terms."""
expr = 2 * x + 3 * x + 1 * x
Expand Down
29 changes: 29 additions & 0 deletions test/test_quadratic_expression.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#!/usr/bin/env python3

from datetime import datetime

import numpy as np
import pandas as pd
import polars as pl
import pytest
from pytz import UTC
from scipy.sparse import csc_matrix
from xarray import DataArray

from linopy import Model, Variable, merge
from linopy.common import TimezoneAlignError
from linopy.constants import FACTOR_DIM, TERM_DIM
from linopy.expressions import LinearExpression, QuadraticExpression
from linopy.testing import assert_quadequal
Expand Down Expand Up @@ -360,3 +364,28 @@ def test_power_of_three(x: Variable) -> None:
x**3
with pytest.raises(TypeError):
(x * x) * (x * x)


def test_timezone_alignment_failure() -> None:
    """Multiplying a UTC-indexed quadratic expression by a tz-naive Series raises TimezoneAlignError."""
    aware_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=UTC,
        name="time",
    )
    naive_index = pd.date_range(
        start=datetime(2025, 1, 1),
        periods=4,
        freq="15min",
        tz=None,
        name="time",
    )
    naive_series = pd.Series(data=1.0, index=naive_index)

    model = Model()
    var = model.add_variables(coords=[aware_index], name="var1")
    quadratic = var * var

    # The point of the test: a useful TimezoneAlignError is expected here,
    # not a "not implemented" error falsely claiming that these types cannot
    # be multiplied together.
    with pytest.raises(TimezoneAlignError):
        _ = quadratic * naive_series
Loading