Skip to content

Commit 6651785

Browse files
Issue-5: Report invalid and below threshold combinations
1 parent 2b4dd5c commit 6651785

File tree

5 files changed

+38
-16
lines changed

5 files changed

+38
-16
lines changed

mbdiff/__main__.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import click
22
from mbdiff.diff_query import DiffQuery
33
from mbdiff.diff import diff_file
4-
from mbdiff.pretty_print import present_explanations
4+
from mbdiff.pretty_print import present_explanations, present_invalid
55

66

77
@click.command()
@@ -13,10 +13,19 @@
1313
def main(data, min_support, min_risk, max_order, query):
1414
metric, op, value = query.split()
1515
query = DiffQuery(metric, op, value)
16-
explanations = diff_file(data, query, max_order, min_risk, min_support)
17-
print("Explanations")
18-
explanations = sorted(explanations, key=lambda x: x[0], reverse=True)
19-
print(present_explanations(explanations))
16+
explanations, invalid = diff_file(data, query, max_order, min_risk, min_support)
17+
if explanations:
18+
explanations = sorted(explanations, key=lambda x: x[0], reverse=True)
19+
print("Explanations")
20+
print(present_explanations(explanations))
21+
else:
22+
print("Could not find any explanations for this input")
23+
if invalid:
24+
print("Attribute combinations below thresholds")
25+
print(present_invalid(invalid))
26+
else:
27+
print("There were no invalid or below threshold attribute combinations")
28+
2029

2130

2231
if __name__ == "__main__":

mbdiff/diff.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from mbdiff.diff_query import DiffQuery
33
from mbdiff.risk_ratio import risk_ratio
44
from mbdiff.attribute_mining import get_combs
5+
from numpy import nan
56

67

78
def diff_file(
@@ -10,7 +11,7 @@ def diff_file(
1011
max_order: int,
1112
min_risk: float,
1213
min_support: float,
13-
) -> list:
14+
):
1415
"""Given a tab delimited file and a distinguishing metric return explanations."""
1516
df = read_csv(path_to_df)
1617
print("Outliers:")
@@ -20,17 +21,22 @@ def diff_file(
2021

2122
def diff(
    df: DataFrame, query: DiffQuery, max_order: int, min_risk: float, min_support: float
):
    """Return explanations plus the attribute combinations that were rejected.

    Every combination produced by ``get_combs`` is scored with ``risk_ratio``
    and sorted into one of two lists.

    Args:
        df: Data to analyse. ``query.mark_groups(df)`` is called on it first
            (presumably this adds the "outlier" column ignored below --
            confirm against ``DiffQuery``).
        query: The distinguishing query used to mark the groups.
        max_order: Passed through to ``get_combs``.
        min_risk: Minimum risk ratio (inclusive) for a combination to count
            as an explanation.
        min_support: Passed through to ``get_combs``.

    Returns:
        A ``(results, invalid)`` tuple. ``results`` is a list of
        ``(risk_ratio, combination)`` tuples whose ratio is a number
        ``>= min_risk``. ``invalid`` is a list of the combinations whose
        ratio is NaN or below ``min_risk``.
    """
    # Function-scope import keeps this block self-contained.
    from math import isnan

    query.mark_groups(df)
    ignored_cols = ["outlier"]
    # Ignore all non categorical columns. Pairing names with dtypes avoids
    # positional integer lookups on the label-indexed ``df.dtypes`` Series.
    ignored_cols.extend(
        column for column, dtype in zip(df.columns, df.dtypes) if dtype != "object"
    )
    combinations = get_combs(df, max_order, min_support, ignored_cols)
    results, invalid = [], []
    for combination in combinations:
        rr = risk_ratio(combination, df)
        # NaN must be detected by value: an identity check (``is nan``) only
        # matches one specific object, and NaN fails every ordering
        # comparison, so it would otherwise slip into ``results``.
        # ``min_risk`` is a minimum, so a ratio equal to it is still valid.
        if isnan(rr) or rr < min_risk:
            invalid.append(combination)
        else:
            results.append((rr, combination))

    return results, invalid

mbdiff/pretty_print.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Present explanations in a more approachable way."""
2-
from typing import List
2+
from typing import List, Dict
33
from tabulate import tabulate
44
from pandas import DataFrame
55

@@ -14,3 +14,8 @@ def present_explanations(explanations: List) -> str:
1414
# NaN means "any value", represent as "-" just like in the original paper
1515
pres_df.fillna("-", inplace=True)
1616
return tabulate(pres_df, headers="keys")
17+
18+
def present_invalid(combinations: List[Dict]) -> str:
    """Pretty print invalid attribute combinations as a plain-text table.

    Args:
        combinations: Attribute combinations, one dict per combination.
            Each dict becomes one table row and each key one column.

    Returns:
        The table rendered by ``tabulate`` with the column names as headers.
    """
    pres_df = DataFrame(combinations)
    # Combinations that do not use every attribute leave NaN cells; show
    # them as "-" exactly like present_explanations does, instead of "nan".
    pres_df = pres_df.fillna("-")
    return tabulate(pres_df, headers="keys")

mbdiff/risk_ratio.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pandas import DataFrame
2+
from numpy import nan
23

34

45
def calc_support(df, column, value) -> float:
@@ -20,11 +21,11 @@ def risk_ratio(attr_combination: dict, df) -> float:
2021
top_d = a0 + ai
2122
denom_d = b0 + bi
2223
if top_d == 0 or denom_d == 0:
23-
return 0
24+
return nan
2425
top = a0 / top_d
2526
denom = b0 / denom_d
2627
if denom < 0.01:
27-
return 0.0
28+
return nan
2829
return top / denom
2930

3031

mbdiff/tests/test_risk_ratio.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from mbdiff.risk_ratio import calc_support, risk_ratio
22
import pytest
3+
from numpy import nan
34

45

56
def test_risk_basic_null_1(df_outliers):
    """Due to lack of cat2 in the inliers the result will be NaN."""
    # Function-scope import keeps this block self-contained.
    from math import isnan

    comb = {"cats": "cat2"}
    # Check NaN by value: ``is nan`` only passes while risk_ratio happens to
    # return the very same object that this module imported.
    assert isnan(risk_ratio(comb, df_outliers))
910

1011

1112
def test_risk_basic_null_2(df_outliers):

0 commit comments

Comments
 (0)