# NOTE(review): overview of the patch below, written as `#` lines ahead of the first
# `diff --git` header. The patch text itself is left byte-for-byte as found: its line
# breaks are collapsed, and several -/+ pairs (`function maybe_optimize(x::Algorithm)`,
# `end`, `# Whenever InitialOptimizations is defined, RadixSort falls`, and the bare
# `-`/`+` pair near the NaN test) read identically on both sides here, so they
# presumably differ only in whitespace — confirm against the original before re-flowing.
#
# src/SortingAlgorithms.jl
#   * Adds a block guarded by `@static if v"1.9.0-alpha" <= VERSION <= v"1.9.1"`:
#       - `Base.getindex(v::Base.Sort.WithoutMissingVector, i::UnitRange)` allocates a
#         `Vector{eltype(v)}` of `length(i)` and fills it from `v.data[i]`.
#       - `_FIVE_ARG_SAFE_DEFAULT_STABLE` is `Base.DEFAULT_STABLE.next`; the inline
#         comment says this skips MissingOptimization (JuliaLang/julia#50171).
#       - `Base.Sort._sort!(v, a, o, kw)` for HeapSortAlg, TimSortAlg, RadixSortAlg and
#         CombSortAlg pulls `lo hi scratch` out with `Base.Sort.@getkw`, calls
#         `sort!(v, lo, hi, a, o)`, and returns `scratch`.
#     In the `else` branch `_FIVE_ARG_SAFE_DEFAULT_STABLE` is `Base.DEFAULT_STABLE`.
#   * In `sort!(v, lo, hi, ::TimSortAlg, o)`, the call under "Make a run of length
#     minrun" now passes `_FIVE_ARG_SAFE_DEFAULT_STABLE` where it passed `DEFAULT_STABLE`.
#   * NOTE(review): the new `getindex` method is defined on a Base function with only
#     Base-owned argument types, and the block relies on non-exported names
#     (`WithoutMissingVector`, `@getkw`, `DEFAULT_STABLE.next`); presumably acceptable
#     only because of the VERSION gate — re-check if that range is ever widened.
#
# test/runtests.jl
#   * Adds `am`, built from `a` by keeping each element when `rand() < .9` and using
#     `missing` otherwise.
#   * Adds `SortingAlgorithms.TimSortAlg()` to the list of algorithms looped over.
#   * Adds a check labelled "legacy 3-argument calling convention":
#     `sort!(copy(a), alg, Base.Order.Forward)` must equal `b`.
#   * Inside `if alg != RadixSort`, adds checks on `am` via `sort`, `sortperm` and
#     `sortperm!`: each result must be `issorted` and keep the same number of `missing`
#     values; the three results are compared with `.===` only when `alg == TimSort`.
#   * NOTE(review): `a` and `am` come from `rand` and no seeding is visible in these
#     hunks, so a failure may not reproduce from run to run — confirm whether a seed is
#     set elsewhere in the file.
#   * NOTE(review): the `alg == TimSort` guard may not match the bare
#     `SortingAlgorithms.TimSortAlg()` entry when `maybe_optimize` wraps the algorithm —
#     verify whether the stability check is meant to cover that entry too.
diff --git a/src/SortingAlgorithms.jl b/src/SortingAlgorithms.jl index e3df2f7..6b08cd1 100644 --- a/src/SortingAlgorithms.jl +++ b/src/SortingAlgorithms.jl @@ -16,12 +16,12 @@ struct TimSortAlg <: Algorithm end struct RadixSortAlg <: Algorithm end struct CombSortAlg <: Algorithm end -function maybe_optimize(x::Algorithm) +function maybe_optimize(x::Algorithm) isdefined(Base.Sort, :InitialOptimizations) ? Base.Sort.InitialOptimizations(x) : x -end +end const HeapSort = maybe_optimize(HeapSortAlg()) const TimSort = maybe_optimize(TimSortAlg()) -# Whenever InitialOptimizations is defined, RadixSort falls +# Whenever InitialOptimizations is defined, RadixSort falls # back to Base.DEFAULT_STABLE which already includes them. const RadixSort = RadixSortAlg() @@ -79,6 +79,27 @@ end # # Original author: @kmsquire +@static if v"1.9.0-alpha" <= VERSION <= v"1.9.1" + function Base.getindex(v::Base.Sort.WithoutMissingVector, i::UnitRange) + out = Vector{eltype(v)}(undef, length(i)) + out .= v.data[i] + out + end + + # skip MissingOptimization due to JuliaLang/julia#50171 + const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE.next + + # Explicitly define conversion from _sort!(v, alg, order, kw) to sort!(v, lo, hi, alg, order) + # To avoid excessively strict dispatch loop detection + function Base.Sort._sort!(v::AbstractVector, a::Union{HeapSortAlg, TimSortAlg, RadixSortAlg, CombSortAlg}, o::Base.Order.Ordering, kw) + Base.Sort.@getkw lo hi scratch + sort!(v, lo, hi, a, o) + scratch + end +else + const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE +end + const Run = UnitRange{Int} const MIN_GALLOP = 7 @@ -490,7 +511,7 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, ::TimSortAlg, o::Ordering) # Make a run of length minrun count = min(minrun, hi-i+1) run_range = i:i+count-1 - sort!(v, i, i+count-1, DEFAULT_STABLE, o) + sort!(v, i, i+count-1, _FIVE_ARG_SAFE_DEFAULT_STABLE, o) else if !issorted(run_range) run_range = last(run_range):first(run_range) diff --git 
a/test/runtests.jl b/test/runtests.jl index 5738105..f8d9970 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,8 +4,9 @@ using StatsBase using Random a = rand(1:10000, 1000) +am = [rand() < .9 ? i : missing for i in a] -for alg in [TimSort, HeapSort, RadixSort, CombSort] +for alg in [TimSort, HeapSort, RadixSort, CombSort, SortingAlgorithms.TimSortAlg()] b = sort(a, alg=alg) @test issorted(b) ix = sortperm(a, alg=alg) @@ -13,6 +14,9 @@ for alg in [TimSort, HeapSort, RadixSort, CombSort] @test issorted(b) @test a[ix] == b + # legacy 3-argument calling convention + @test b == sort!(copy(a), alg, Base.Order.Forward) + b = sort(a, alg=alg, rev=true) @test issorted(b, rev=true) ix = sortperm(a, alg=alg, rev=true) @@ -34,9 +38,26 @@ for alg in [TimSort, HeapSort, RadixSort, CombSort] invpermute!(c, ix) @test c == a - if alg != RadixSort # RadixSort does not work with Lt orderings + if alg != RadixSort # RadixSort does not work with Lt orderings or missing c = sort(a, alg=alg, lt=(>)) @test b == c + + # Issue https://github.com/JuliaData/DataFrames.jl/issues/3340 + bm1 = sort(am, alg=alg) + @test issorted(bm1) + @test count(ismissing, bm1) == count(ismissing, am) + + bm2 = am[sortperm(am, alg=alg)] + @test issorted(bm2) + @test count(ismissing, bm2) == count(ismissing, am) + + bm3 = am[sortperm!(collect(eachindex(am)), am, alg=alg)] + @test issorted(bm3) + @test count(ismissing, bm3) == count(ismissing, am) + + if alg == TimSort # Stable + @test all(bm1 .=== bm2 .=== bm3) + end end c = sort(a, alg=alg, by=x->1/x) @@ -103,8 +124,8 @@ for n in [0:10..., 100, 101, 1000, 1001] # test float sorting with NaNs s = sort(v, alg=alg, order=ord) @test issorted(s, order=ord) - + # This tests that NaNs (which compare equivalent) are treated stably # even when the underlying algorithm is unstable. 
That it happens to # pass is not a part of the public API: @test reinterpret(UInt64, v[map(isnan, v)]) == reinterpret(UInt64, s[map(isnan, s)])