From a1d688f64b9dc2f76969628c49ddb7881417d0da Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 17 Dec 2025 08:18:12 -0500 Subject: [PATCH] Combine FusionStyles --- .github/workflows/IntegrationTest.yml | 1 + Project.toml | 2 +- src/contract/contract.jl | 5 +---- src/matricize.jl | 7 +++---- test/test_fusionstyle.jl | 28 +++++++++++++++++++++++++++ 5 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 test/test_fusionstyle.jl diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 17d2d40..bc64b44 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -21,6 +21,7 @@ jobs: - 'FusionTensors' - 'GradedArrays' - 'ITensorBase' + - 'ITensorNetworksNext' - 'KroneckerArrays' - 'NamedDimsArrays' uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main" diff --git a/Project.toml b/Project.toml index 24c36ed..0341623 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.6.5" +version = "0.6.6" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/contract/contract.jl b/src/contract/contract.jl index a77eecc..92b5cdb 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -10,10 +10,7 @@ end Matricize() = Matricize(ReshapeFusion()) function default_contract_algorithm(A1::Type{<:AbstractArray}, A2::Type{<:AbstractArray}) - style1 = FusionStyle(A1) - style2 = FusionStyle(A2) - style1 == style2 || error("Styles must match.") - return Matricize(style1) + return Matricize(FusionStyle(FusionStyle(A1), FusionStyle(A2))) end # Required interface if not using diff --git a/src/matricize.jl b/src/matricize.jl index 01e9f66..48d476b 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -6,6 +6,8 @@ abstract type FusionStyle end FusionStyle(x) = FusionStyle(typeof(x)) FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) +FusionStyle(style1::Style, style2::Style) where {Style <: FusionStyle} = Style() +FusionStyle(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion() # ======================================= misc ======================================== function trivial_axis( @@ -72,10 +74,7 @@ function tensor_product_axis(r1::AbstractUnitRange, r2::AbstractUnitRange) return tensor_product_axis(style, r1, r2) end function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange) - style1 = FusionStyle(r1) - style2 = FusionStyle(r2) - style1 == style2 || error("Styles must match.") - return style1 + return FusionStyle(FusionStyle(r1), FusionStyle(r2)) end function fused_axis( diff --git a/test/test_fusionstyle.jl b/test/test_fusionstyle.jl new file mode 100644 index 0000000..a37a812 --- /dev/null +++ b/test/test_fusionstyle.jl @@ -0,0 +1,28 @@ +using TensorAlgebra: TensorAlgebra as TA, FusionStyle, Matricize, ReshapeFusion +using Test: @test, @testset + +module FusionStyleTestUtils + using TensorAlgebra: TensorAlgebra as TA + struct MyArray{T, N, A <: AbstractArray{T, N}} <: AbstractArray{T, N} + parent::A + end + struct MyArrayFusion <: TA.FusionStyle end + TA.FusionStyle(::Type{<:MyArray}) = MyArrayFusion() +end +using .FusionStyleTestUtils: MyArray, MyArrayFusion + +@testset "FusionStyle" begin + a1 = randn(2, 2) + a2 = MyArray(randn(2, 2)) + @test FusionStyle(a1) ≡ ReshapeFusion() + @test FusionStyle(a2) ≡ MyArrayFusion() + @test FusionStyle(typeof(a1)) ≡ ReshapeFusion() + @test FusionStyle(ReshapeFusion(), ReshapeFusion()) ≡ ReshapeFusion() + @test FusionStyle(MyArrayFusion(), MyArrayFusion()) ≡ MyArrayFusion() + @test FusionStyle(MyArrayFusion(), ReshapeFusion()) ≡ ReshapeFusion() + @test FusionStyle(ReshapeFusion(), MyArrayFusion()) ≡ ReshapeFusion() + @test TA.default_contract_algorithm(typeof(a1), typeof(a1)) ≡ Matricize(ReshapeFusion()) + @test TA.default_contract_algorithm(typeof(a1), typeof(a2)) ≡ Matricize(ReshapeFusion()) + @test TA.default_contract_algorithm(typeof(a2), typeof(a1)) ≡ Matricize(ReshapeFusion()) + @test TA.default_contract_algorithm(typeof(a2), typeof(a2)) ≡ Matricize(MyArrayFusion()) +end