Skip to content

Commit 7ea8fd0

Browse files
committed
update
1 parent 730ece6 commit 7ea8fd0

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

src/arithematics.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ julia> one(s)
336336
337337
```
338338
"""
339-
struct TreeConfigEnumerator{N,S,C}
339+
# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
340+
mutable struct TreeConfigEnumerator{N,S,C}
340341
tag::TreeTag
341342
data::StaticElementVector{N,S,C}
342343
left::TreeConfigEnumerator{N,S,C}
@@ -379,22 +380,37 @@ function printnode(io::IO, t::TreeConfigEnumerator)
379380
end
380381
end
381382

382-
function Base.length(x::TreeConfigEnumerator)
383+
Base.length(x::TreeConfigEnumerator) = _length(x, IdDict{typeof(x), Int}())
384+
385+
function _length(x, d)
386+
haskey(d, x) && return d[x]
383387
if x.tag === SUM
384-
return length(x.left) + length(x.right)
388+
l = _length(x.left, d) + _length(x.right, d)
389+
d[x] = l
390+
return l
385391
elseif x.tag === PROD
386-
return length(x.left) * length(x.right)
392+
l = _length(x.left, d) * _length(x.right, d)
393+
d[x] = l
394+
return l
387395
elseif x.tag === ZERO
388396
return 0
389397
else
390398
return 1
391399
end
392400
end
393401

394-
function num_nodes(x::TreeConfigEnumerator)
395-
x.tag == ZERO && return 1
396-
x.tag == LEAF && return 1
397-
return num_nodes(x.left) + num_nodes(x.right) + 1
402+
num_nodes(x::TreeConfigEnumerator) = _num_nodes(x, IdDict{typeof(x), Int}())
403+
function _num_nodes(x, d)
404+
haskey(d, x) && return 0
405+
if x.tag == ZERO
406+
res = 1
407+
elseif x.tag == LEAF
408+
res = 1
409+
else
410+
res = _num_nodes(x.left, d) + _num_nodes(x.right, d) + 1
411+
end
412+
d[x] = res
413+
return res
398414
end
399415

400416
function Base.:(==)(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}

0 commit comments

Comments
 (0)