|
1 |
| -using Adapt: Adapt, WrappedArray |
| 1 | +using Adapt: Adapt, WrappedArray, adapt |
2 | 2 | using ArrayLayouts: zero!
|
3 | 3 | using BlockArrays:
|
4 | 4 | BlockArrays,
|
@@ -337,60 +337,29 @@ function Base.Array(a::AnyAbstractBlockSparseArray)
|
337 | 337 | return Array{eltype(a)}(a)
|
338 | 338 | end
|
339 | 339 |
|
340 |
| -using SparseArraysBase: ReplacedUnstoredSparseArray |
341 |
| - |
342 |
| -# Wraps a block sparse array but replaces the unstored values. |
343 |
| -# This is used in printing in order to customize printing |
344 |
| -# of zero/unstored values. |
345 |
| -struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <: |
346 |
| - AbstractBlockSparseArray{T,N} |
347 |
| - parent::Parent |
348 |
| - getunstoredblock::F |
349 |
| -end |
350 |
| -Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent |
351 |
| -Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a)) |
352 |
| -function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray) |
353 |
| - return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock) |
354 |
| -end |
355 |
| - |
356 |
| -# This is copied from `SparseArraysBase.jl` since it is not part |
357 |
| -# of the public interface. |
358 |
| -# Like `Char` but prints without quotes. |
359 |
| -struct UnquotedChar <: AbstractChar |
360 |
| - char::Char |
| 340 | +function SparseArraysBase.isstored( |
| 341 | + A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N} |
| 342 | +) where {N} |
| 343 | + bI = BlockIndex(findblockindex.(axes(A), I)) |
| 344 | + bA = blocks(A) |
| 345 | + return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...) |
361 | 346 | end
|
362 |
| -Base.show(io::IO, c::UnquotedChar) = print(io, c.char) |
363 |
| -Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c) |
364 | 347 |
|
365 |
| -using FillArrays: Fill |
366 |
| -struct GetUnstoredBlockShow{Axes} |
367 |
| - axes::Axes |
368 |
| -end |
369 |
| -@inline function (f::GetUnstoredBlockShow)( |
370 |
| - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} |
371 |
| -) where {N} |
372 |
| - # TODO: Make sure this works for sparse or block sparse blocks, immutable |
373 |
| - # blocks, diagonal blocks, etc.! |
374 |
| - b_size = ntuple(ndims(a)) do d |
375 |
| - return length(f.axes[d][Block(I[d])]) |
| 348 | +function Base.replace_in_print_matrix( |
| 349 | + A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString |
| 350 | +) |
| 351 | + return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s) |
| 352 | +end |
| 353 | + |
| 354 | +# attempt to catch things that wrap GPU arrays |
| 355 | +function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray) |
| 356 | + X_cpu = adapt(Array, X) |
| 357 | + if typeof(X_cpu) === typeof(X) # prevent infinite recursion |
| 358 | + # need to specify ndims to allow specialized code for vector/matrix |
| 359 | + @allowscalar @invoke Base.print_array( |
| 360 | + io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)} |
| 361 | + ) |
| 362 | + else |
| 363 | + Base.print_array(io, X_cpu) |
376 | 364 | end
|
377 |
| - return Fill(UnquotedChar('.'), b_size) |
378 |
| -end |
379 |
| -# TODO: Use `Base.to_indices`. |
380 |
| -@inline function (f::GetUnstoredBlockShow)( |
381 |
| - a::AbstractArray{<:Any,N}, I::CartesianIndex{N} |
382 |
| -) where {N} |
383 |
| - return f(a, Tuple(I)...) |
384 |
| -end |
385 |
| - |
386 |
| -# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function |
387 |
| -# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`. |
388 |
| -function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray) |
389 |
| - summary(io, a) |
390 |
| - isempty(a) && return nothing |
391 |
| - print(io, ":") |
392 |
| - println(io) |
393 |
| - a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a))) |
394 |
| - @allowscalar Base.print_array(io, a′) |
395 |
| - return nothing |
396 | 365 | end
|
0 commit comments