Skip to content

Commit 2e7b90b

Browse files
committed
Fix for set hashes
Signed-off-by: James Hamlin <jfhamlin@gmail.com>
1 parent b30d514 commit 2e7b90b

File tree

5 files changed

+76
-53
lines changed

5 files changed

+76
-53
lines changed

pkg/compiler/analyze.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,7 +1438,7 @@ func (a *Analyzer) parseCaseStar(form interface{}, env Env) (*ast.Node, error) {
14381438
mapEntry := First(seq).(IMapEntry)
14391439
key := mapEntry.Key()
14401440
val := mapEntry.Val()
1441-
1441+
14421442
// Convert key to int64
14431443
var keyInt int64
14441444
switch k := key.(type) {
@@ -1457,20 +1457,20 @@ func (a *Analyzer) parseCaseStar(form interface{}, env Env) (*ast.Node, error) {
14571457
default:
14581458
return nil, exInfo(fmt.Sprintf("case* map key must be integer, got %T", key), nil)
14591459
}
1460-
1460+
14611461
// Extract the vector [test-constant result-expr]
14621462
if Count(val) != 2 {
14631463
return nil, exInfo("case* map value must be a 2-element vector", nil)
14641464
}
1465-
1465+
14661466
testConstant := First(val)
14671467
resultExpr := second(val)
1468-
1468+
14691469
// Check if this is a collision case
14701470
// In Clojure, entries whose keys are in skipCheck should be evaluated directly
14711471
// without comparison (they contain condp expressions for collision handling)
14721472
hasCollision := false
1473-
1473+
14741474
// Check if the map key is in the skip check set
14751475
switch k := key.(type) {
14761476
case int64:
@@ -1498,7 +1498,7 @@ func (a *Analyzer) parseCaseStar(form interface{}, env Env) (*ast.Node, error) {
14981498
hasCollision = true
14991499
}
15001500
}
1501-
1501+
15021502
// Analyze the test constant and result expression
15031503
var testConstantNode *ast.Node
15041504
if !hasCollision {
@@ -1508,13 +1508,13 @@ func (a *Analyzer) parseCaseStar(form interface{}, env Env) (*ast.Node, error) {
15081508
return nil, err
15091509
}
15101510
}
1511-
1511+
15121512
// Analyze the result expression (or condp for collisions)
15131513
resultExprNode, err := a.analyzeForm(resultExpr, env)
15141514
if err != nil {
15151515
return nil, err
15161516
}
1517-
1517+
15181518
entries = append(entries, ast.CaseEntry{
15191519
Key: keyInt,
15201520
TestConstant: testConstantNode,

pkg/lang/apersistentset.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ type (
77
AFn
88
IPersistentSet
99
IHashEq
10+
Hasher
1011
}
1112
)
1213

@@ -28,6 +29,21 @@ func apersistentsetEquiv(a APersistentSet, o any) bool {
2829
return true
2930
}
3031

32+
func apersistentsetHash(hc *uint32, a APersistentSet) uint32 {
33+
if *hc != 0 {
34+
return *hc
35+
}
36+
37+
// Following Clojure's APersistentSet.hashCode logic:
38+
// Sum of hash values of all elements
39+
var hash uint32 = 0
40+
for seq := a.Seq(); seq != nil; seq = seq.Next() {
41+
hash += Hash(seq.First())
42+
}
43+
*hc = hash
44+
return hash
45+
}
46+
3147
func apersistentsetHashEq(hc *uint32, a APersistentSet) uint32 {
3248
if *hc != 0 {
3349
return *hc

pkg/lang/set.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ func (s *Set) Equiv(o any) bool {
153153
return apersistentsetEquiv(s, o)
154154
}
155155

156+
func (s *Set) Hash() uint32 {
157+
return apersistentsetHash(&s.hash, s)
158+
}
159+
156160
func (s *Set) HashEq() uint32 {
157161
return apersistentsetHashEq(&s.hasheq, s)
158162
}

pkg/runtime/evalast.go

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ func (env *environment) EvalASTCase(n *ast.Node) (interface{}, error) {
470470
// Determine the lookup key based on test type
471471
var lookupKey int64
472472
testType := caseNode.TestType.(lang.Keyword)
473-
473+
474474
switch testType {
475475
case lang.KWInt:
476476
// For integer test type, use the value directly
@@ -493,7 +493,7 @@ func (env *environment) EvalASTCase(n *ast.Node) (interface{}, error) {
493493
if caseNode.Mask != 0 {
494494
lookupKey = int64(uint32(lookupKey>>uint(caseNode.Shift)) & uint32(caseNode.Mask))
495495
}
496-
496+
497497
case lang.KWHashIdentity:
498498
// Use identity hash for keywords
499499
hash := lang.IdentityHash(testVal)
@@ -503,7 +503,7 @@ func (env *environment) EvalASTCase(n *ast.Node) (interface{}, error) {
503503
} else {
504504
lookupKey = int64(uint32(hash>>uint(caseNode.Shift)) & uint32(caseNode.Mask))
505505
}
506-
506+
507507
case lang.KWHashEquiv:
508508
// Use hash for general values
509509
hash := lang.Hash(testVal)
@@ -513,7 +513,7 @@ func (env *environment) EvalASTCase(n *ast.Node) (interface{}, error) {
513513
} else {
514514
lookupKey = int64(uint32(hash>>uint(caseNode.Shift)) & uint32(caseNode.Mask))
515515
}
516-
516+
517517
default:
518518
return nil, fmt.Errorf("unknown test type: %v", testType)
519519
}
@@ -522,7 +522,7 @@ func (env *environment) EvalASTCase(n *ast.Node) (interface{}, error) {
522522
// Following Clojure's implementation: find the entry whose key matches lookupKey
523523
// If that entry is marked as collision, evaluate result directly (it's a condp)
524524
// Otherwise, verify the test value matches before evaluating result
525-
525+
526526
for _, entry := range caseNode.Entries {
527527
if entry.Key != lookupKey {
528528
continue
@@ -538,54 +538,54 @@ func (env *environment) EvalASTCase(n *ast.Node) (interface{}, error) {
538538
} else {
539539
// Non-collision case, need to verify the actual value matches
540540
if testType == lang.KWInt {
541-
// For integers with shift/mask, we need to verify the actual value
542-
// because multiple values can map to the same key
543-
if caseNode.Mask != 0 {
544-
// Need to check actual value matches
545-
expectedVal, err := env.EvalAST(entry.TestConstant)
546-
if err != nil {
547-
return nil, err
548-
}
549-
if lang.Equals(testVal, expectedVal) {
550-
result, err := env.EvalAST(entry.ResultExpr)
551-
if err != nil {
552-
return nil, err
553-
}
554-
return result, nil
555-
}
556-
} else {
557-
// For integers without shift/mask, the key match is sufficient
541+
// For integers with shift/mask, we need to verify the actual value
542+
// because multiple values can map to the same key
543+
if caseNode.Mask != 0 {
544+
// Need to check actual value matches
545+
expectedVal, err := env.EvalAST(entry.TestConstant)
546+
if err != nil {
547+
return nil, err
548+
}
549+
if lang.Equals(testVal, expectedVal) {
558550
result, err := env.EvalAST(entry.ResultExpr)
559551
if err != nil {
560552
return nil, err
561553
}
562554
return result, nil
563555
}
564556
} else {
565-
// For hash-based dispatch, verify the actual value matches
566-
expectedVal, err := env.EvalAST(entry.TestConstant)
557+
// For integers without shift/mask, the key match is sufficient
558+
result, err := env.EvalAST(entry.ResultExpr)
567559
if err != nil {
568560
return nil, err
569561
}
570-
571-
// Use appropriate comparison based on test type
572-
var matches bool
573-
if testType == lang.KWHashIdentity {
574-
matches = testVal == expectedVal
575-
} else {
576-
matches = lang.Equals(testVal, expectedVal)
562+
return result, nil
563+
}
564+
} else {
565+
// For hash-based dispatch, verify the actual value matches
566+
expectedVal, err := env.EvalAST(entry.TestConstant)
567+
if err != nil {
568+
return nil, err
569+
}
570+
571+
// Use appropriate comparison based on test type
572+
var matches bool
573+
if testType == lang.KWHashIdentity {
574+
matches = testVal == expectedVal
575+
} else {
576+
matches = lang.Equals(testVal, expectedVal)
577+
}
578+
if matches {
579+
result, err := env.EvalAST(entry.ResultExpr)
580+
if err != nil {
581+
return nil, err
577582
}
578-
if matches {
579-
result, err := env.EvalAST(entry.ResultExpr)
580-
if err != nil {
581-
return nil, err
582-
}
583-
return result, nil
583+
return result, nil
584584
}
585585
}
586586
}
587587
}
588-
588+
589589
// No match found, evaluate default
590590
return env.EvalAST(caseNode.Default)
591591
}

test/glojure/test_glojure/case.glj

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
;;
33
;; This file tests all the different paths through case evaluation:
44
;; - Compact integer dispatch
5-
;; - Sparse integer dispatch
5+
;; - Sparse integer dispatch
66
;; - Hash-based dispatch with keywords (identity)
77
;; - Hash-based dispatch with strings and other types (equiv)
88
;; - Default expressions
@@ -21,7 +21,7 @@
2121
(when (and (seq? case*-form#) (= 'case* (first case*-form#)))
2222
(let [switch-type# (nth case*-form# 6)
2323
test-type# (nth case*-form# 7)]
24-
(is (= ~expected-switch switch-type#)
24+
(is (= ~expected-switch switch-type#)
2525
(str "Expected switch type " ~expected-switch " but got " switch-type#))
2626
(is (= ~expected-type test-type#)
2727
(str "Expected test type " ~expected-type " but got " test-type#))))))
@@ -140,7 +140,7 @@
140140
(testing "nil as test value"
141141
(is (= :nil (case nil nil :nil :not-nil)))
142142
(is (= :not-nil (case 1 nil :nil :not-nil))))
143-
143+
144144
(testing "nil as result"
145145
(is (nil? (case 1 1 nil 2 :two)))
146146
(is (nil? (case :foo :foo nil :bar false)))))
@@ -151,10 +151,10 @@
151151
(is (= :true (case true true :true false :false)))
152152
(is (= :false (case false true :true false :false)))
153153
(is (= :default (case nil true :true false :false :default))))
154-
154+
155155
(testing "Boolean false vs nil distinction"
156156
;; This is a specific test case that was failing
157-
(is (= :boolean-false-result
157+
(is (= :boolean-false-result
158158
(case false
159159
false :boolean-false-result
160160
nil :nil-result
@@ -173,8 +173,11 @@
173173
(is (= :pair (case [1 2] [1 2] :pair [2 1] :reversed)))
174174
(is (= :map (case {:a :map :of :kws}
175175
{:a :map :of :kws} :map
176+
:other)))
177+
(is (= :set (case #{1 2 3}
178+
#{1 2 3} :set
176179
:other))))
177-
180+
178181
(testing "Lists must use vectors as test constants"
179182
;; Lists can't be written directly as test constants because parentheses group multiple constants
180183
;; But a vector can be used to match a list
@@ -197,7 +200,7 @@
197200
1 :one
198201
2 :two)
199202
(is (= 1 @counter) "Test expression should be evaluated exactly once")))
200-
203+
201204
(testing "Result expressions not evaluated until matched"
202205
(let [side-effect (atom [])]
203206
(case 2

0 commit comments

Comments
 (0)