Skip to content

Commit 7c3d86a

Browse files
authored
Simplify show implementation (#46)
1 parent e4c91aa commit 7c3d86a

File tree

3 files changed

+25
-56
lines changed

3 files changed

+25
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.25"
4+
version = "0.2.26"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 23 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Adapt: Adapt, WrappedArray
1+
using Adapt: Adapt, WrappedArray, adapt
22
using ArrayLayouts: zero!
33
using BlockArrays:
44
BlockArrays,
@@ -337,60 +337,29 @@ function Base.Array(a::AnyAbstractBlockSparseArray)
337337
return Array{eltype(a)}(a)
338338
end
339339

340-
using SparseArraysBase: ReplacedUnstoredSparseArray
341-
342-
# Wraps a block sparse array but replaces the unstored values.
343-
# This is used in printing in order to customize printing
344-
# of zero/unstored values.
345-
struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <:
346-
AbstractBlockSparseArray{T,N}
347-
parent::Parent
348-
getunstoredblock::F
349-
end
350-
Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent
351-
Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a))
352-
function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray)
353-
return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock)
354-
end
355-
356-
# This is copied from `SparseArraysBase.jl` since it is not part
357-
# of the public interface.
358-
# Like `Char` but prints without quotes.
359-
struct UnquotedChar <: AbstractChar
360-
char::Char
340+
function SparseArraysBase.isstored(
341+
A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}
342+
) where {N}
343+
bI = BlockIndex(findblockindex.(axes(A), I))
344+
bA = blocks(A)
345+
return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...)
361346
end
362-
Base.show(io::IO, c::UnquotedChar) = print(io, c.char)
363-
Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c)
364347

365-
using FillArrays: Fill
366-
struct GetUnstoredBlockShow{Axes}
367-
axes::Axes
368-
end
369-
@inline function (f::GetUnstoredBlockShow)(
370-
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
371-
) where {N}
372-
# TODO: Make sure this works for sparse or block sparse blocks, immutable
373-
# blocks, diagonal blocks, etc.!
374-
b_size = ntuple(ndims(a)) do d
375-
return length(f.axes[d][Block(I[d])])
348+
function Base.replace_in_print_matrix(
349+
A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString
350+
)
351+
return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s)
352+
end
353+
354+
# attempt to catch things that wrap GPU arrays
355+
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
356+
X_cpu = adapt(Array, X)
357+
if typeof(X_cpu) === typeof(X) # prevent infinite recursion
358+
# need to specify ndims to allow specialized code for vector/matrix
359+
@allowscalar @invoke Base.print_array(
360+
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
361+
)
362+
else
363+
Base.print_array(io, X_cpu)
376364
end
377-
return Fill(UnquotedChar('.'), b_size)
378-
end
379-
# TODO: Use `Base.to_indices`.
380-
@inline function (f::GetUnstoredBlockShow)(
381-
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
382-
) where {N}
383-
return f(a, Tuple(I)...)
384-
end
385-
386-
# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function
387-
# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`.
388-
function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray)
389-
summary(io, a)
390-
isempty(a) && return nothing
391-
print(io, ":")
392-
println(io)
393-
a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a)))
394-
@allowscalar Base.print_array(io, a′)
395-
return nothing
396365
end

test/test_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,7 @@ arrayts = (Array, JLArray)
11391139
a = BlockSparseMatrix{elt,arrayt{elt,2}}([2, 2], [2, 2])
11401140
@allowscalar a[1, 2] = 12
11411141
@test sprint(show, "text/plain", a) ==
1142-
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)). .\n $(zero(eltype(a))) $(zero(eltype(a))). .\n ───────────┼──────\n . .. .\n . .. ."
1142+
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ \n ⋅ ⋅ "
11431143
end
11441144
end
11451145
@testset "TypeParameterAccessors.position" begin

0 commit comments

Comments
 (0)