Skip to content

Commit f542c6e

Browse files
authored
Combine FusionStyles (#107)
1 parent 28e60c0 commit f542c6e

File tree

5 files changed

+34
-9
lines changed

5 files changed

+34
-9
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ jobs:
2121
- 'FusionTensors'
2222
- 'GradedArrays'
2323
- 'ITensorBase'
24+
- 'ITensorNetworksNext'
2425
- 'KroneckerArrays'
2526
- 'NamedDimsArrays'
2627
uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.6.5"
4+
version = "0.6.6"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/contract/contract.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@ end
1010
Matricize() = Matricize(ReshapeFusion())
1111

1212
function default_contract_algorithm(A1::Type{<:AbstractArray}, A2::Type{<:AbstractArray})
13-
style1 = FusionStyle(A1)
14-
style2 = FusionStyle(A2)
15-
style1 == style2 || error("Styles must match.")
16-
return Matricize(style1)
13+
return Matricize(FusionStyle(FusionStyle(A1), FusionStyle(A2)))
1714
end
1815

1916
# Required interface if not using

src/matricize.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ abstract type FusionStyle end
66

77
FusionStyle(x) = FusionStyle(typeof(x))
88
FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,)))
9+
FusionStyle(style1::Style, style2::Style) where {Style <: FusionStyle} = Style()
10+
FusionStyle(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion()
911

1012
# ======================================= misc ========================================
1113
function trivial_axis(
@@ -72,10 +74,7 @@ function tensor_product_axis(r1::AbstractUnitRange, r2::AbstractUnitRange)
7274
return tensor_product_axis(style, r1, r2)
7375
end
7476
function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange)
75-
style1 = FusionStyle(r1)
76-
style2 = FusionStyle(r2)
77-
style1 == style2 || error("Styles must match.")
78-
return style1
77+
return FusionStyle(FusionStyle(r1), FusionStyle(r2))
7978
end
8079

8180
function fused_axis(

test/test_fusionstyle.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using TensorAlgebra: TensorAlgebra as TA, FusionStyle, Matricize, ReshapeFusion
2+
using Test: @test, @testset
3+
4+
module FusionStyleTestUtils
5+
using TensorAlgebra: TensorAlgebra as TA
6+
struct MyArray{T, N, A <: AbstractArray{T, N}} <: AbstractArray{T, N}
7+
parent::A
8+
end
9+
struct MyArrayFusion <: TA.FusionStyle end
10+
TA.FusionStyle(::Type{<:MyArray}) = MyArrayFusion()
11+
end
12+
using .FusionStyleTestUtils: MyArray, MyArrayFusion
13+
14+
@testset "FusionStyle" begin
15+
a1 = randn(2, 2)
16+
a2 = MyArray(randn(2, 2))
17+
@test FusionStyle(a1) ReshapeFusion()
18+
@test FusionStyle(a2) MyArrayFusion()
19+
@test FusionStyle(typeof(a1)) ReshapeFusion()
20+
@test FusionStyle(ReshapeFusion(), ReshapeFusion()) ReshapeFusion()
21+
@test FusionStyle(MyArrayFusion(), MyArrayFusion()) MyArrayFusion()
22+
@test FusionStyle(MyArrayFusion(), ReshapeFusion()) ReshapeFusion()
23+
@test FusionStyle(ReshapeFusion(), MyArrayFusion()) ReshapeFusion()
24+
@test TA.default_contract_algorithm(typeof(a1), typeof(a1)) Matricize(ReshapeFusion())
25+
@test TA.default_contract_algorithm(typeof(a1), typeof(a2)) Matricize(ReshapeFusion())
26+
@test TA.default_contract_algorithm(typeof(a2), typeof(a1)) Matricize(ReshapeFusion())
27+
@test TA.default_contract_algorithm(typeof(a2), typeof(a2)) Matricize(MyArrayFusion())
28+
end

0 commit comments

Comments
 (0)