@@ -336,7 +336,8 @@ julia> one(s)
336336
337337```
338338"""
339- struct TreeConfigEnumerator{N,S,C}
339+ # it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
340+ mutable struct TreeConfigEnumerator{N,S,C}
340341 tag:: TreeTag
341342 data:: StaticElementVector{N,S,C}
342343 left:: TreeConfigEnumerator{N,S,C}
@@ -379,22 +380,37 @@ function printnode(io::IO, t::TreeConfigEnumerator)
379380 end
380381end
381382
382- function Base. length (x:: TreeConfigEnumerator )
383+ Base. length (x:: TreeConfigEnumerator ) = _length (x, IdDict {typeof(x), Int} ())
384+
385+ function _length (x, d)
386+ haskey (d, x) && return d[x]
383387 if x. tag === SUM
384- return length (x. left) + length (x. right)
388+ l = _length (x. left, d) + _length (x. right, d)
389+ d[x] = l
390+ return l
385391 elseif x. tag === PROD
386- return length (x. left) * length (x. right)
392+ l = _length (x. left, d) * _length (x. right, d)
393+ d[x] = l
394+ return l
387395 elseif x. tag === ZERO
388396 return 0
389397 else
390398 return 1
391399 end
392400end
393401
394- function num_nodes (x:: TreeConfigEnumerator )
395- x. tag == ZERO && return 1
396- x. tag == LEAF && return 1
397- return num_nodes (x. left) + num_nodes (x. right) + 1
402+ num_nodes (x:: TreeConfigEnumerator ) = _num_nodes (x, IdDict {typeof(x), Int} ())
403+ function _num_nodes (x, d)
404+ haskey (d, x) && return 0
405+ if x. tag == ZERO
406+ res = 1
407+ elseif x. tag == LEAF
408+ res = 1
409+ else
410+ res = _num_nodes (x. left, d) + _num_nodes (x. right, d) + 1
411+ end
412+ d[x] = res
413+ return res
398414end
399415
400416function Base.:(== )(x:: TreeConfigEnumerator{N,S,C} , y:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
0 commit comments