Skip to content

Commit ec7b123

Browse files
authored
Add scalar keyword argument to flatten (#3283)
1 parent 436b686 commit ec7b123

File tree

3 files changed

+131
-15
lines changed

3 files changed

+131
-15
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
* Add `haskey` and `get` methods to `DataFrameColumns`
2727
to make it support dictionary interface more completely
2828
([#3282](https://github.com/JuliaData/DataFrames.jl/pull/3282))
29+
* Allow passing `scalar` keyword argument in `flatten`
30+
([#3283](https://github.com/JuliaData/DataFrames.jl/pull/3283))
2931

3032
## Bug fixes
3133

src/abstractdataframe/abstractdataframe.jl

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,7 +2259,7 @@ function Missings.allowmissing(df::AbstractDataFrame,
22592259
end
22602260

22612261
"""
2262-
flatten(df::AbstractDataFrame, cols)
2262+
flatten(df::AbstractDataFrame, cols; scalar::Type=Union{})
22632263
22642264
When columns `cols` of data frame `df` have iterable elements that define
22652265
`length` (for example a `Vector` of `Vector`s), return a `DataFrame` where each
@@ -2273,6 +2273,11 @@ returned `DataFrame` will affect `df`.
22732273
22742274
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR).
22752275
2276+
If `scalar` is passed, values of this type stored in the flattened columns
2277+
are treated as scalars and broadcast (repeated) as many times as needed to match
2278+
the lengths of the values stored in the other flattened columns. If all values
2279+
in a row are scalars, a single row is produced.
2280+
22762281
$METADATA_FIXED
22772282
22782283
# Examples
@@ -2334,10 +2339,33 @@ julia> df3 = DataFrame(a=[1, 2], b=[[1, 2], [3, 4]], c=[[5, 6], [7]])
23342339
23352340
julia> flatten(df3, [:b, :c])
23362341
ERROR: ArgumentError: Lengths of iterables stored in columns :b and :c are not the same in row 2
2342+
2343+
julia> df4 = DataFrame(a=[1, 2, 3],
2344+
b=[[1, 2], missing, missing],
2345+
c=[[5, 6], missing, [7, 8]])
2346+
3×3 DataFrame
2347+
Row │ a b c
2348+
│ Int64 Array…? Array…?
2349+
─────┼─────────────────────────
2350+
1 │ 1 [1, 2] [5, 6]
2351+
2 │ 2 missing missing
2352+
3 │ 3 missing [7, 8]
2353+
2354+
julia> flatten(df4, [:b, :c], scalar=Missing)
2355+
5×3 DataFrame
2356+
Row │ a b c
2357+
│ Int64 Int64? Int64?
2358+
─────┼─────────────────────────
2359+
1 │ 1 1 5
2360+
2 │ 1 2 6
2361+
3 │ 2 missing missing
2362+
4 │ 3 missing 7
2363+
5 │ 3 missing 8
23372364
```
23382365
"""
23392366
function flatten(df::AbstractDataFrame,
2340-
cols::Union{ColumnIndex, MultiColumnIndex})
2367+
cols::Union{ColumnIndex, MultiColumnIndex};
2368+
scalar::Type=Union{})
23412369
_check_consistency(df)
23422370

23432371
idxcols = index(df)[cols]
@@ -2348,15 +2376,16 @@ function flatten(df::AbstractDataFrame,
23482376
end
23492377

23502378
col1 = first(idxcols)
2351-
lengths = length.(df[!, col1])
2352-
for col in idxcols
2353-
v = df[!, col]
2354-
if any(x -> length(x[1]) != x[2], zip(v, lengths))
2355-
r = findfirst(x -> x != 0, length.(v) .- lengths)
2356-
colnames = _names(df)
2357-
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
2358-
"and :$(colnames[col]) are not the same in row $r"))
2359-
end
2379+
lengths = Int[x isa scalar ? -1 : length(x) for x in df[!, col1]]
2380+
for (i, coli) in enumerate(idxcols)
2381+
i == 1 && continue
2382+
update_lengths!(lengths, df[!, coli], scalar, df, col1, coli)
2383+
end
2384+
2385+
# handle case where in all columns we had a scalar
2386+
# in this case we keep it one time
2387+
for i in 1:length(lengths)
2388+
lengths[i] == -1 && (lengths[i] = 1)
23602389
end
23612390

23622391
new_df = similar(df[!, Not(cols)], sum(lengths))
@@ -2366,18 +2395,38 @@ function flatten(df::AbstractDataFrame,
23662395
length(idxcols) > 1 && sort!(idxcols)
23672396
for col in idxcols
23682397
col_to_flatten = df[!, col]
2369-
fast_path = eltype(col_to_flatten) isa AbstractVector &&
2398+
fast_path = eltype(col_to_flatten) <: AbstractVector &&
23702399
!isempty(col_to_flatten)
2371-
flattened_col = fast_path ?
2372-
reduce(vcat, col_to_flatten) :
2373-
collect(Iterators.flatten(col_to_flatten))
2400+
flattened_col = if fast_path
2401+
reduce(vcat, col_to_flatten)
2402+
elseif scalar === Union{}
2403+
collect(Iterators.flatten(col_to_flatten))
2404+
else
2405+
collect(Iterators.flatten(v isa scalar ? Iterators.repeated(v, l) : v
2406+
for (l, v) in zip(lengths, col_to_flatten)))
2407+
end
23742408
insertcols!(new_df, col, _names(df)[col] => flattened_col)
23752409
end
23762410

23772411
_copy_all_note_metadata!(new_df, df)
23782412
return new_df
23792413
end
23802414

2415+
"""
    update_lengths!(lengths, col, scalar, df, col1, coli)

Reconcile the per-row lengths recorded in `lengths` with the values stored in `col`.

Values of `col` that are instances of `scalar` are ignored. For every other value
its `length` is computed and then:

* if `lengths[i]` is `-1` (no length recorded for row `i` yet), the length is stored;
* if `lengths[i]` holds a different length, an `ArgumentError` is thrown naming the
  columns at positions `col1` and `coli` of `df` and the offending row `i`.

`lengths` is modified in place and `nothing` is returned.
"""
function update_lengths!(lengths::Vector{Int}, col::AbstractVector, scalar::Type,
                         df::AbstractDataFrame, col1::Integer, coli::Integer)
    for (i, value) in enumerate(col)
        if !(value isa scalar)
            len = length(value)
            recorded = lengths[i]
            if recorded == -1
                # first non-scalar value seen in this row: it fixes the row's length
                lengths[i] = len
            elseif recorded != len
                colnames = _names(df)
                throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
                                    "and :$(colnames[coli]) are not the same in row $i"))
            end
        end
    end
    return nothing
end
2429+
23812430
function repeat_lengths!(longnew::AbstractVector, shortold::AbstractVector,
23822431
lengths::AbstractVector{Int})
23832432
counter = 1

test/reshape.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,71 @@ end
431431
@test flatten(DataFrame(), All()) == DataFrame()
432432
end
433433

434+
# Tests of the `scalar` keyword argument of `flatten`.
# Results containing `missing` are compared with `isequal`, because `==` on data
# frames holding `missing` evaluates to `missing` rather than `true`.
@testset "flatten with scalar" begin
    df = DataFrame(a=[1, 2, 3],
                   b=[[1, 2], missing, [3, 4]],
                   c=[[5, 6], missing, missing])
    # flattening a column of plain numbers leaves the data frame unchanged
    @test isequal(flatten(df, :a), df)
    # without `scalar`, `missing` has no `length`
    @test_throws MethodError flatten(df, :b)
    @test isequal(flatten(df, :b, scalar=Missing),
                  DataFrame(a=[1, 1, 2, 3, 3],
                            b=[1, 2, missing, 3, 4],
                            c=[[5, 6], [5, 6], missing, missing, missing]))
    @test isequal(flatten(df, [:b, :c], scalar=Missing),
                  DataFrame(a=[1, 1, 2, 3, 3],
                            b=[1, 2, missing, 3, 4],
                            c=[5, 6, missing, missing, missing]))
    # every value is a scalar, so each row is kept exactly once
    @test isequal(flatten(df, [:b, :c], scalar=Any), df)

    df = DataFrame(a=missing, b=[1], c=missing, d=[[1, 2]])
    @test_throws ArgumentError flatten(df, All(), scalar=Missing)
    @test isequal(flatten(df, Not(:d), scalar=Missing),
                  DataFrame(a=missing, b=1, c=missing, d=[[1, 2]]))
    @test isequal(flatten(df, Not(:b), scalar=Missing),
                  DataFrame(a=[missing, missing], b=[1, 1],
                            c=[missing, missing], d=[1, 2]))

    df = DataFrame(a="xy", b=[[1, 2]])
    @test flatten(df, [:a, :b]) == DataFrame(a=['x', 'y'], b=[1, 2])
    @test flatten(df, [:a, :b], scalar=String) ==
          DataFrame(a=["xy", "xy"], b=[1, 2])

    df = DataFrame(a=[[1], [], [3, 4], missing], b = missings(4), id=1:4)
    @test isequal(flatten(df, [:a, :b], scalar=Missing),
                  DataFrame(a=[1, 3, 4, missing], b=missings(4), id=[1, 3, 3, 4]))
    df = DataFrame(id=1:10, x=[1:i-1 for i in 1:10])
    df.y = [iseven(last(v)) ? missing : v for v in df.x]
    @test isequal(flatten(df, [:x, :y], scalar=Missing),
                  DataFrame(id=reduce(vcat, [fill(i, i-1) for i in 2:10]),
                            x=reduce(vcat, [1:i for i in 1:9]),
                            y=reduce(vcat, [iseven(i) ? missings(i) : (1:i) for i in 1:9])))

    # Below are tests showing handling of strings
    df = DataFrame(id=1:5,
                   col1=["a", missing, 1:2, 3:4, 5:6],
                   col2=[11:12, 111:112, 1111:1112, missing, "b"])
    @test isequal(flatten(df, [:col1, :col2], scalar=Union{Missing, AbstractString}),
                  DataFrame(id=[1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
                            col1=["a", "a", missing, missing, 1, 2, 3, 4, 5, 6],
                            col2=[11, 12, 111, 112, 1111, 1112, missing, missing, "b", "b"]))
    @test_throws MethodError flatten(df, [:col1, :col2])
    @test_throws ArgumentError flatten(df, [:col1, :col2], scalar=Missing)
    @test_throws MethodError flatten(df, [:col1, :col2], scalar=AbstractString)

    df = DataFrame(id=1:5,
                   col1=["ab", missing, 1:2, 3:4, 5:6],
                   col2=[11:12, 111:112, 1111:1112, missing, "cd"])
    @test isequal(flatten(df, [:col1, :col2], scalar=Union{Missing, AbstractString}),
                  DataFrame(id=[1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
                            col1=["ab", "ab", missing, missing, 1, 2, 3, 4, 5, 6],
                            col2=[11, 12, 111, 112, 1111, 1112, missing, missing, "cd", "cd"]))
    @test_throws MethodError flatten(df, [:col1, :col2])
    # two-character strings are iterated as characters when only `Missing` is scalar
    @test isequal(flatten(df, [:col1, :col2], scalar=Missing),
                  DataFrame(id=[1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
                            col1=['a', 'b', missing, missing, 1, 2, 3, 4, 5, 6],
                            col2=[11, 12, 111, 112, 1111, 1112, missing, missing, 'c', 'd']))
    @test_throws MethodError flatten(df, [:col1, :col2], scalar=AbstractString)
end
498+
434499
@testset "stack categorical test" begin
435500
Random.seed!(1234)
436501
d1 = DataFrame(a=repeat([1:3;], inner=[4]),

0 commit comments

Comments
 (0)