Fail fast in IPF when the per-category marginal totals disagree, and derive
the "two or more" ACS buckets as total - (none + one) so that each category's
marginals add up to the table total.

NOTE(review): hunk indentation was reconstructed from a whitespace-collapsed
copy of this patch, and the "index" hashes predate the corrected recipe
expressions below; apply with whitespace tolerance or regenerate the patch.

diff --git a/synthpop/ipf/ipf.py b/synthpop/ipf/ipf.py
index b3beddb..67715e5 100644
--- a/synthpop/ipf/ipf.py
+++ b/synthpop/ipf/ipf.py
@@ -52,6 +52,12 @@ def calc_diff(x, y):
 
     iterations = 0
 
+    sums = marginals.groupby(level=0).sum()
+    if sums.max() - sums.min() > 0.01:
+        raise RuntimeError(
+            'Marginals do not add up ipf will not converge: {}'.format(sums))
+    del sums
+
     list_of_loc = [
         ((flat_joint_dist[idx[0]] == idx[1]).values, marginals[idx])
         for idx in marginals.index
diff --git a/synthpop/ipf/test/test_ipf.py b/synthpop/ipf/test/test_ipf.py
index f30142b..ce15175 100644
--- a/synthpop/ipf/test/test_ipf.py
+++ b/synthpop/ipf/test/test_ipf.py
@@ -45,3 +45,22 @@ def test_larger_ipf():
 
     with pytest.raises(RuntimeError):
         ipf.calculate_constraints(marginals, joint_dist, max_iterations=2)
+
+
+def test_not_add_ipf():
+    # cat_owner marginals sum to 100 but car_color marginals sum to 101,
+    # so the category totals disagree and IPF must raise instead of iterating.
+    marginal_midx = pd.MultiIndex.from_tuples(
+        [('cat_owner', 'yes'),
+         ('cat_owner', 'no'),
+         ('car_color', 'blue'),
+         ('car_color', 'red'),
+         ('car_color', 'green')])
+    marginals = pd.Series([60, 40, 50, 31, 20], index=marginal_midx)
+    joint_dist_midx = pd.MultiIndex.from_product(
+        [('yes', 'no'), ('blue', 'red', 'green')],
+        names=['cat_owner', 'car_color'])
+    joint_dist = pd.Series([8, 4, 2, 5, 3, 2], index=joint_dist_midx)
+
+    with pytest.raises(RuntimeError):
+        ipf.calculate_constraints(marginals, joint_dist)
diff --git a/synthpop/recipes/starter.py b/synthpop/recipes/starter.py
index 5c7bf39..a4b161b 100644
--- a/synthpop/recipes/starter.py
+++ b/synthpop/recipes/starter.py
@@ -69,10 +69,10 @@ def __init__(self, key, state, county, tract=None):
             "+ B19001_017E",
             ("cars", "none"): "B08201_002E",
             ("cars", "one"): "B08201_003E",
-            ("cars", "two or more"): "B08201_004E + B08201_005E + B08201_006E",
+            ("cars", "two or more"): "B08201_001E - (B08201_002E + B08201_003E)",
             ("workers", "none"): "B08202_002E",
             ("workers", "one"): "B08202_003E",
-            ("workers", "two or more"): "B08202_004E + B08202_005E"
+            ("workers", "two or more"): "B08202_001E - (B08202_002E + B08202_003E)"
         }, index_cols=['state', 'county', 'tract', 'block group'])
 
         population = ['B01001_001E']
diff --git a/synthpop/recipes/starter2.py b/synthpop/recipes/starter2.py
index 6d79d80..50dbe2d 100644
--- a/synthpop/recipes/starter2.py
+++ b/synthpop/recipes/starter2.py
@@ -105,10 +105,10 @@ def __init__(self, key, state, county, tract=None):
             ("hh_cars", "none"): "B08201_002E",
             ("hh_cars", "one"): "B08201_003E",
             ("hh_cars", "two or more"):
-            "B08201_004E + B08201_005E + B08201_006E",
+            "B08201_001E - (B08201_002E + B08201_003E)",
             ("hh_workers", "none"): "B08202_002E",
             ("hh_workers", "one"): "B08202_003E",
-            ("hh_workers", "two or more"): "B08202_004E + B08202_005E",
+            ("hh_workers", "two or more"): "B08202_001E - (B08202_002E + B08202_003E)",
             ("tenure_mover", "own recent"): "B25038_003E",
             ("tenure_mover", "own not recent"): "B25038_002E - B25038_003E",
             ("tenure_mover", "rent recent"): "B25038_010E",