Skip to content

Commit a74e4b0

Browse files
committed
Closes #5089: Add a balancedSegStringRetrieval function
1 parent 6243a6f commit a74e4b0

File tree

3 files changed

+628
-7
lines changed

3 files changed

+628
-7
lines changed

arkouda/numpy/strings.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,6 +3040,40 @@ def concatenate_uniquely(strings: List[Strings]) -> Strings:
30403040

30413041
return Strings.from_return_msg(cast(str, rep_msg))
30423042

3043+
@staticmethod
def concatenate_uniquely2(strings: List[Strings]) -> Strings:
    """
    Merge several Strings objects into one Strings holding each distinct value once.

    The order of the values in the result is not guaranteed.

    Parameters
    ----------
    strings : List[Strings]
        The segmented string objects whose values are to be combined.

    Returns
    -------
    Strings
        A newly created Strings object made up of the unique values.

    Raises
    ------
    ValueError
        If ``strings`` is empty.
    """
    from arkouda.client import generic_msg

    if not strings:
        raise ValueError("Must provide at least one Strings object")

    # The request carries only the name of each input Strings object.
    request_args = {"names": [entry.name for entry in strings]}

    # Hand the work to the server and wrap its reply in a Strings object.
    reply = generic_msg(cmd="concatenateUniquely2", args=request_args)
    return Strings.from_return_msg(cast(str, reply))
3076+
30433077
def argsort(
30443078
self,
30453079
algorithm: SortingAlgorithm = SortingAlgorithm.RadixSortLSD,

src/ConcatenateMsg.chpl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,4 +653,100 @@ module ConcatenateMsg
653653
return new MsgTuple(repMsg, MsgType.NORMAL);
654654
}
655655
registerFunction("concatenateUniquely", concatenateUniqueStrMsg, getModuleName());
656+
657+
/*
  Handler for the "concatenateUniquely2" command.

  Reads the "names" argument (a list of SegString names), gathers the strings
  of every named SegString into one set per locale, repartitions those strings
  with repartitionByHashArray, removes duplicates once more on every locale,
  and builds a new SegString from what is left.

  :arg cmd: the command string (not read by this procedure)
  :arg msgArgs: message arguments; must contain "names"
  :arg st: symbol table used to look up the inputs and to register the result

  :returns: a NORMAL MsgTuple describing the created SegString
  :throws: an error with errorClass "UnknownSymbolError" when anything inside
           the per-input try block fails
*/
proc concatenateUniqueStrMsg2(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
    var repMsg: string;
    var names = msgArgs.get("names").toScalarList(string);
    var n = names.size;

    cmLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
                   "concatenate unique strings from %i arrays: %?".format(n, names));

    // Each locale gets its own set, accumulated across all inputs.
    var localeSets = makeDistArray(numLocales, set(string));

    // Initialize sets
    coforall loc in Locales do on loc {
        localeSets[here.id] = new set(string);
    }

    // Collect all unique strings from each input SegmentedString
    for i in 0..#n {
        var rawName = names[i];
        // Only the part of the name before the first '+' is used for lookup.
        var (strName, _) = rawName.splitMsgToTuple('+', 2);
        try {
            var segString = getSegString(strName, st);
            cmLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
                           "Processing SegString: %s".format(strName));

            // Grab the strings by locale and throw them in the sets.
            // NOTE(review): presumably yields one innerArray(string) per
            // locale, indexed by locale id -- confirm against
            // balancedSegStringRetrieval.
            var stringArrays = balancedSegStringRetrieval(segString);
            coforall loc in Locales do on loc {
                ref myStrings = stringArrays[here.id].Arr;
                var locSet = new set(string);

                // Reduce to unique strings within this locale
                forall str in myStrings with (+ reduce locSet) {
                    locSet.add(str);
                }

                // Fold this input's strings into my locale's running set
                localeSets[here.id] |= locSet;
            }
        } catch e: Error {
            // NOTE(review): every error raised inside the try block, not only
            // a failed lookup, is reported with this message and error class.
            throw getErrorWithContext(
                msg="lookup for %s failed".format(rawName),
                lineNumber=getLineNumber(),
                routineName=getRoutineName(),
                moduleName=getModuleName(),
                errorClass="UnknownSymbolError");
        }
    }

    // Convert every locale's set to an innerArray. This runs once, after all
    // inputs have been merged, so the arrays are also populated (as empty
    // arrays) when `names` is empty.
    var stringsInLocale: [PrivateSpace] innerArray(string);
    coforall loc in Locales do on loc {
        stringsInLocale[here.id] = new innerArray({0..#localeSets[here.id].size}, string);
        stringsInLocale[here.id].Arr = localeSets[here.id].toArray();
    }

    // Repartition the strings by their hash
    var distributedStrings = repartitionByHashArray(string, stringsInLocale);

    coforall loc in Locales do on loc {
        ref myStrings = distributedStrings[here.id].Arr;
        var strSet = new set(string);

        // Perform another reduction by uniqueness on the strings in this set:
        // the same string may have arrived from more than one locale.
        forall str in myStrings with (+ reduce strSet) {
            strSet.add(str);
        }

        // The innerArray is replaced by a new one sized to the deduplicated
        // set before its contents are assigned. Overwriting myStrings in
        // place reportedly caused a memory error, possibly because the
        // domain sizes did not match.
        distributedStrings[here.id] = new innerArray({0..#strSet.size}, string);
        ref myStrings2 = distributedStrings[here.id].Arr;
        myStrings2 = strSet.toArray();
    }

    // Convert back from innerArray(string) to SegString
    var retString = segStringFromInnerArray(distributedStrings, st);

    // Build the reply describing the new SegString
    repMsg = "created " + st.attrib(retString.name) + "+created bytes.size %?".format(retString.nBytes);

    cmLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
                   "Created unique concatenated SegmentedString: %s".format(st.attrib(retString.name)));

    return new MsgTuple(repMsg, MsgType.NORMAL);
}
registerFunction("concatenateUniquely2", concatenateUniqueStrMsg2, getModuleName());
656752
}

0 commit comments

Comments
 (0)