Skip to content

Commit 9c2ef00

Browse files
authored
Fix in growacq with implied constraints (#20)
* do not empty bias before return * Add implied constraints from bias to cl * fix when bias is given and not generated in growacq * Update .gitignore * test findc2 also on growacq * leftover input in nqueens
1 parent 2dc71ad commit 9c2ef00

10 files changed

Lines changed: 22 additions & 10 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ notebooks/.ipynb_checkpoints/Prediction-based CA system-checkpoint.ipynb
1515
dist/pycona-0.2.4-py3-none-any.whl
1616
dist/pycona-0.2.4.tar.gz
1717
notebooks/.ipynb_checkpoints/Comparing different algorithms and methods-checkpoint.ipynb
18+
testing.py

pycona/active_algorithms/genacq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6868
if self.env.verbose >= 1:
6969
print(f"\nLearned {self.env.metrics.cl} constraints in "
7070
f"{self.env.metrics.total_queries} queries.")
71-
self.env.instance.bias = []
7271
return self.env.instance
7372

7473
self.env.metrics.increase_generation_time(gen_end - gen_start)

pycona/active_algorithms/growacq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..answering_queries import Oracle, UserOracle
88
from .. import Metrics
99
from ..ca_environment import ProbaActiveCAEnv
10-
10+
from ..utils import get_con_subset
1111

1212
class GrowAcq(AlgorithmCAInteractive):
1313
"""
@@ -67,6 +67,11 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6767
print(f"\nGrowAcq: calling inner_algorithm for {len(Y)}/{n_vars} variables")
6868
self.env.instance = self.inner_algorithm.learn(self.env.instance, oracle, verbose=verbose, X=Y, metrics=self.env.metrics)
6969

70+
# Add implied constraints from bias to cl
71+
implied_constraints = get_con_subset(self.env.instance.bias, Y)
72+
self.env.instance.cl.extend(implied_constraints)
73+
self.env.instance.bias = [c for c in self.env.instance.bias if c not in set(implied_constraints)] # remove implied constraints from bias
74+
7075
if verbose >= 3:
7176
print("C_L: ", len(self.env.instance.cl))
7277
print("B: ", len(self.env.instance.bias))

pycona/active_algorithms/mineacq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6666
if self.env.verbose >= 1:
6767
print(f"\nLearned {self.env.metrics.cl} constraints in "
6868
f"{self.env.metrics.total_queries} queries.")
69-
self.env.instance.bias = []
7069
return self.env.instance
7170

7271
self.env.metrics.increase_generation_time(gen_end - gen_start)

pycona/active_algorithms/mquacq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6262
if self.env.verbose >= 1:
6363
print(f"\nLearned {self.env.metrics.cl} constraints in "
6464
f"{self.env.metrics.membership_queries_count} queries.")
65-
self.env.instance.bias = []
6665
return self.env.instance
6766

6867
self.env.metrics.increase_generation_time(gen_end - gen_start)

pycona/active_algorithms/mquacq2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6969
if self.env.verbose >= 1:
7070
print(f"\nLearned {self.env.metrics.cl} constraints in "
7171
f"{self.env.metrics.membership_queries_count} queries.")
72-
self.env.instance.bias = []
7372
return self.env.instance
7473

7574
self.env.metrics.increase_generated_queries()

pycona/active_algorithms/pquacq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
6262
if self.env.verbose >= 1:
6363
print(f"\nLearned {self.env.metrics.cl} constraints in "
6464
f"{self.env.metrics.membership_queries_count} queries.")
65-
self.env.instance.bias = []
6665
return self.env.instance
6766

6867
self.env.metrics.increase_generation_time(gen_end - gen_start)

pycona/active_algorithms/quacq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbos
5858
if self.env.verbose >= 1:
5959
print(f"\nLearned {self.env.metrics.cl} constraints in "
6060
f"{self.env.metrics.membership_queries_count} queries.")
61-
self.env.instance.bias = []
6261
return self.env.instance
6362

6463
self.env.metrics.increase_generation_time(gen_end - gen_start)

pycona/benchmarks/nqueens.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ..answering_queries.constraint_oracle import ConstraintOracle
44
from ..problem_instance import ProblemInstance, absvar
55

6-
def construct_nqueens_problem(n):
6+
def construct_nqueens_problem(n=8):
77

88
parameters = {"n": n}
99

@@ -43,6 +43,4 @@ def construct_nqueens_problem(n):
4343
for c in oracle.constraints:
4444
print(c)
4545

46-
input("Press Enter to continue...")
47-
4846
return instance, oracle

tests/test_finc.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,19 @@ def test_findc2_with_golomb4(self):
7777
learned_not_oracle += cp.any([~c for c in oracle.constraints])
7878
assert not learned_not_oracle.solve()
7979

80+
# test growacq
81+
alg = ca.GrowAcq(ca_env, alg)
82+
li2 = alg.learn(instance, oracle)
83+
84+
# oracle model imply learned?
85+
oracle_not_learned = cp.Model(oracle.constraints)
86+
oracle_not_learned += cp.any([~c for c in li2._cl])
87+
assert not oracle_not_learned.solve()
88+
89+
# learned model imply oracle?
90+
learned_not_oracle = cp.Model(li2._cl)
91+
learned_not_oracle += cp.any([~c for c in oracle.constraints])
92+
assert not learned_not_oracle.solve()
93+
8094

8195

0 commit comments

Comments
 (0)