
Commit 9cff956

test: Add comprehensive edge case tests for TokenSplitter
1 parent 57a517b commit 9cff956

1 file changed: 354 additions, 0 deletions
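
Before the diff itself, a minimal sketch of the API these tests exercise. The option names and the CreateDocuments helper are taken from the diff below; the main-package wrapper, the github.com/tmc/langchaingo/textsplitter import path, and the sample text are illustrative assumptions, not part of the commit:

package main

import (
    "fmt"
    "log"

    "github.com/tmc/langchaingo/textsplitter"
)

func main() {
    // Chunk size and overlap are measured in tokens, not characters.
    splitter := textsplitter.NewTokenSplitter(
        textsplitter.WithChunkSize(10),
        textsplitter.WithChunkOverlap(2),
    )
    docs, err := textsplitter.CreateDocuments(splitter,
        []string{"This is a longer text that will be split into token-sized chunks."},
        nil, // no per-document metadata
    )
    if err != nil {
        log.Fatal(err)
    }
    for i, doc := range docs {
        fmt.Printf("chunk %d: %q\n", i, doc.PageContent)
    }
}
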
@@ -0,0 +1,354 @@
package textsplitter

import (
    "strings"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "github.com/tmc/langchaingo/schema"
)

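// TestTokenSplitterEdgeCases covers degenerate inputs: empty and
// whitespace-only text, a chunk size far larger than the input, zero and
// maximal overlap, multi-byte runes, words longer than the chunk size,
// and whitespace-heavy formatting.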
func TestTokenSplitterEdgeCases(t *testing.T) {
    t.Parallel()

    t.Run("empty text", func(t *testing.T) {
        t.Parallel()
        splitter := NewTokenSplitter()
        docs, err := CreateDocuments(splitter, []string{""}, nil)
        require.NoError(t, err)
        assert.Equal(t, []schema.Document{
            {PageContent: "", Metadata: map[string]any{}},
        }, docs)
    })

    t.Run("single character", func(t *testing.T) {
        t.Parallel()
        splitter := NewTokenSplitter(WithChunkSize(1))
        docs, err := CreateDocuments(splitter, []string{"a"}, nil)
        require.NoError(t, err)
        assert.Equal(t, []schema.Document{
            {PageContent: "a", Metadata: map[string]any{}},
        }, docs)
    })

    t.Run("whitespace only", func(t *testing.T) {
        t.Parallel()
        splitter := NewTokenSplitter(WithChunkSize(10))
        docs, err := CreateDocuments(splitter, []string{" \n\t "}, nil)
        require.NoError(t, err)
        assert.Len(t, docs, 1)
        assert.Equal(t, " \n\t ", docs[0].PageContent)
    })

    t.Run("very large chunk size", func(t *testing.T) {
        t.Parallel()
        text := "This is a test text that should not be split because the chunk size is very large."
        splitter := NewTokenSplitter(WithChunkSize(10000))
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.Len(t, docs, 1)
        assert.Equal(t, text, docs[0].PageContent)
    })

    t.Run("zero chunk overlap", func(t *testing.T) {
        t.Parallel()
        text := "This is a longer text that should be split into multiple chunks without any overlap between them."
        splitter := NewTokenSplitter(
            WithChunkSize(10),
            WithChunkOverlap(0),
        )
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.Greater(t, len(docs), 1)

        // Verify no overlap by checking that consecutive chunks do not
        // share a boundary word.
        for i := 1; i < len(docs); i++ {
            prev := strings.TrimSpace(docs[i-1].PageContent)
            curr := strings.TrimSpace(docs[i].PageContent)
            if prev != "" && curr != "" {
                prevWords := strings.Fields(prev)
                currWords := strings.Fields(curr)
                if len(prevWords) > 0 && len(currWords) > 0 {
                    assert.NotEqual(t, prevWords[len(prevWords)-1], currWords[0],
                        "Chunks should not overlap when overlap is 0")
                }
            }
        }
    })

    t.Run("chunk overlap equals chunk size", func(t *testing.T) {
        t.Parallel()
        text := "Word1 Word2 Word3 Word4 Word5 Word6 Word7 Word8"
        splitter := NewTokenSplitter(
            WithChunkSize(5),
            WithChunkOverlap(5),
        )
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.Greater(t, len(docs), 1)

        // With overlap equal to chunk size, chunks should still be
        // produced and none of them should be empty.
        for i := 1; i < len(docs); i++ {
            assert.NotEmpty(t, docs[i].PageContent)
        }
    })

    t.Run("unicode and special characters", func(t *testing.T) {
        t.Parallel()
        text := "Hello 世界! 🌍 This contains émojis and spëcial characters: àáâãäå"
        splitter := NewTokenSplitter(WithChunkSize(20))
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.NotEmpty(t, docs)

        // Verify all content is preserved. Substring checks tolerate any
        // whitespace differences at chunk boundaries.
        combined := ""
        for _, doc := range docs {
            combined += doc.PageContent
        }
        assert.Contains(t, combined, "世界")
        assert.Contains(t, combined, "🌍")
        assert.Contains(t, combined, "émojis")
        assert.Contains(t, combined, "àáâãäå")
    })

    t.Run("very long single word", func(t *testing.T) {
        t.Parallel()
        longWord := strings.Repeat("a", 1000)
        splitter := NewTokenSplitter(WithChunkSize(10))
        docs, err := CreateDocuments(splitter, []string{longWord}, nil)
        require.NoError(t, err)
        assert.Greater(t, len(docs), 1)

        // Verify the long word is split without losing any characters.
        combined := ""
        for _, doc := range docs {
            combined += doc.PageContent
        }
        assert.Equal(t, longWord, combined)
    })

    t.Run("newlines and formatting preservation", func(t *testing.T) {
        t.Parallel()
        text := "Line 1\n\nLine 2\n\n\nLine 3\n\t\tIndented line\n"
        splitter := NewTokenSplitter(WithChunkSize(15))
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.NotEmpty(t, docs)

        // Verify formatting is preserved across the split.
        combined := ""
        for _, doc := range docs {
            combined += doc.PageContent
        }
        assert.Contains(t, combined, "\n\n")
        assert.Contains(t, combined, "\t\t")
    })
}

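// TestTokenSplitterDifferentModels runs the splitter with model names
// rather than explicit encodings; each model name is resolved to a
// tiktoken encoding (gpt-4 and gpt-3.5-turbo map to cl100k_base,
// text-davinci-003 to p50k_base) that determines how tokens are counted.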
func TestTokenSplitterDifferentModels(t *testing.T) {
    t.Parallel()

    testCases := []struct {
        name      string
        modelName string
        text      string
    }{
        {
            name:      "gpt-4 model",
            modelName: "gpt-4",
            text:      "This is a test for GPT-4 tokenization.",
        },
        {
            name:      "gpt-3.5-turbo model",
            modelName: "gpt-3.5-turbo",
            text:      "This is a test for GPT-3.5-turbo tokenization.",
        },
        {
            name:      "text-davinci-003 model",
            modelName: "text-davinci-003",
            text:      "This is a test for text-davinci-003 tokenization.",
        },
    }

    for _, tc := range testCases {
        tc := tc
        t.Run(tc.name, func(t *testing.T) {
            t.Parallel()
            splitter := NewTokenSplitter(
                WithModelName(tc.modelName),
                WithChunkSize(10),
            )
            docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
            require.NoError(t, err)
            assert.NotEmpty(t, docs)

            // Verify content is preserved.
            combined := ""
            for _, doc := range docs {
                combined += doc.PageContent
            }
            assert.Contains(t, combined, "test")
            assert.Contains(t, combined, "tokenization")
        })
    }
}

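// TestTokenSplitterDifferentEncodings selects the tiktoken encoding
// directly instead of going through a model name; cl100k_base,
// p50k_base, and r50k_base are the standard tiktoken encoding names.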
func TestTokenSplitterDifferentEncodings(t *testing.T) {
    t.Parallel()

    testCases := []struct {
        name         string
        encodingName string
        text         string
    }{
        {
            name:         "cl100k_base encoding",
            encodingName: "cl100k_base",
            text:         "Testing cl100k_base encoding with various tokens.",
        },
        {
            name:         "p50k_base encoding",
            encodingName: "p50k_base",
            text:         "Testing p50k_base encoding with various tokens.",
        },
        {
            name:         "r50k_base encoding",
            encodingName: "r50k_base",
            text:         "Testing r50k_base encoding with various tokens.",
        },
    }

    for _, tc := range testCases {
        tc := tc
        t.Run(tc.name, func(t *testing.T) {
            t.Parallel()
            splitter := NewTokenSplitter(
                WithEncodingName(tc.encodingName),
                WithChunkSize(15),
            )
            docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
            require.NoError(t, err)
            assert.NotEmpty(t, docs)

            // Verify content is preserved.
            combined := ""
            for _, doc := range docs {
                combined += doc.PageContent
            }
            assert.Contains(t, combined, "Testing")
            assert.Contains(t, combined, "encoding")
        })
    }
}

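// TestTokenSplitterSpecialTokens exercises the special-token options:
// a token listed via WithAllowedSpecial must survive splitting intact,
// and disallowing all specials must leave ordinary text unaffected.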
func TestTokenSplitterSpecialTokens(t *testing.T) {
    t.Parallel()

    t.Run("with allowed special tokens", func(t *testing.T) {
        t.Parallel()
        text := "This text contains <|endoftext|> special token."
        splitter := NewTokenSplitter(
            WithAllowedSpecial([]string{"<|endoftext|>"}),
            WithChunkSize(20),
        )
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.NotEmpty(t, docs)

        // Verify the special token is preserved.
        combined := ""
        for _, doc := range docs {
            combined += doc.PageContent
        }
        assert.Contains(t, combined, "<|endoftext|>")
    })

    t.Run("with disallowed special tokens", func(t *testing.T) {
        t.Parallel()
        text := "This is normal text without special tokens."
        splitter := NewTokenSplitter(
            WithDisallowedSpecial([]string{"all"}),
            WithChunkSize(20),
        )
        docs, err := CreateDocuments(splitter, []string{text}, nil)
        require.NoError(t, err)
        assert.NotEmpty(t, docs)

        // Verify content is preserved.
        combined := ""
        for _, doc := range docs {
            combined += doc.PageContent
        }
        assert.Contains(t, combined, "normal text")
    })
}

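// TestTokenSplitterErrorHandling pins down where failures surface:
// NewTokenSplitter does not validate the model or encoding name, so the
// tiktoken lookup error appears only when CreateDocuments runs.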
func TestTokenSplitterErrorHandling(t *testing.T) {
    t.Parallel()

    t.Run("invalid model name", func(t *testing.T) {
        t.Parallel()
        splitter := NewTokenSplitter(WithModelName("invalid-model-name"))
        _, err := CreateDocuments(splitter, []string{"test"}, nil)
        require.Error(t, err)
        assert.Contains(t, err.Error(), "tiktoken")
    })

    t.Run("invalid encoding name", func(t *testing.T) {
        t.Parallel()
        splitter := NewTokenSplitter(WithEncodingName("invalid-encoding"))
        _, err := CreateDocuments(splitter, []string{"test"}, nil)
        require.Error(t, err)
        assert.Contains(t, err.Error(), "tiktoken")
    })
}

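// TestTokenSplitterConsistency checks that splitting is deterministic
// (identical inputs produce identical documents) and order-preserving.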
309+
310+
func TestTokenSplitterConsistency(t *testing.T) {
311+
t.Parallel()
312+
313+
t.Run("consistent splitting", func(t *testing.T) {
314+
t.Parallel()
315+
text := "This is a consistent test text that should be split the same way every time."
316+
splitter := NewTokenSplitter(
317+
WithChunkSize(10),
318+
WithChunkOverlap(2),
319+
)
320+
321+
// Split the same text multiple times
322+
docs1, err1 := CreateDocuments(splitter, []string{text}, nil)
323+
require.NoError(t, err1)
324+
325+
docs2, err2 := CreateDocuments(splitter, []string{text}, nil)
326+
require.NoError(t, err2)
327+
328+
// Results should be identical
329+
assert.Equal(t, docs1, docs2)
330+
})
331+
332+
t.Run("order preservation", func(t *testing.T) {
333+
t.Parallel()
334+
text := "First sentence. Second sentence. Third sentence. Fourth sentence."
335+
splitter := NewTokenSplitter(WithChunkSize(8))
336+
docs, err := CreateDocuments(splitter, []string{text}, nil)
337+
require.NoError(t, err)
338+
339+
// Verify order is preserved
340+
combined := ""
341+
for _, doc := range docs {
342+
combined += doc.PageContent
343+
}
344+
345+
firstPos := strings.Index(combined, "First")
346+
secondPos := strings.Index(combined, "Second")
347+
thirdPos := strings.Index(combined, "Third")
348+
fourthPos := strings.Index(combined, "Fourth")
349+
350+
assert.True(t, firstPos < secondPos)
351+
assert.True(t, secondPos < thirdPos)
352+
assert.True(t, thirdPos < fourthPos)
353+
})
354+
}
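
A caller-side sketch of the failure mode the error-handling tests pin down: NewTokenSplitter accepts an invalid model name, and the tiktoken error only surfaces once CreateDocuments tries to split. The main-package wrapper and log message are illustrative assumptions:

package main

import (
    "log"

    "github.com/tmc/langchaingo/textsplitter"
)

func main() {
    splitter := textsplitter.NewTokenSplitter(
        textsplitter.WithModelName("invalid-model-name"), // construction does not validate the name
    )
    if _, err := textsplitter.CreateDocuments(splitter, []string{"test"}, nil); err != nil {
        // As the tests assert, the error message mentions tiktoken.
        log.Printf("token splitting failed: %v", err)
    }
}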
