Skip to content

Commit b8ac505

Browse files
committed
Improve test coverage and type of exceptions thrown for unsuported stuff
1 parent 6323969 commit b8ac505

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/FastBertTokenizer.Tests/BatchEnumerators.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,30 @@ public void TokenizeWithBatchEnumerator(Dictionary<int, string> articles)
3434
}
3535
}
3636

37+
[Theory]
38+
[MemberData(nameof(WikipediaSimpleData.GetArticlesDict), MemberType = typeof(WikipediaSimpleData))]
39+
public void TokenizeWithBatchEnumeratorNonGenericEnumerable(Dictionary<int, string> articles)
40+
{
41+
IEnumerable<(int, string)> Source()
42+
{
43+
foreach (var (key, value) in articles.Take(100)) // some data is sufficient here, as TokenizeWithBatchEnumerator already checks the same.
44+
{
45+
yield return (key, value);
46+
}
47+
}
48+
49+
var enumerable = (System.Collections.IEnumerable)_uut.CreateBatchEnumerator(Source(), 512, 10, 0);
50+
foreach (var batch in enumerable)
51+
{
52+
var casted = (TokenizedBatch<int>)batch;
53+
casted.InputIds.Span[0].ShouldBe(101); // [CLS] = 101
54+
}
55+
56+
var anotherEnumerable = (System.Collections.IEnumerable)_uut.CreateBatchEnumerator(Source(), 512, 10, 0);
57+
var enumerator = anotherEnumerable.GetEnumerator();
58+
Should.Throw<InvalidOperationException>(enumerator.Reset);
59+
}
60+
3761
[Theory]
3862
[MemberData(nameof(WikipediaSimpleData.GetArticlesDict), MemberType = typeof(WikipediaSimpleData))]
3963
public async Task TokenizeWithAsyncBatchEnumerator(Dictionary<int, string> articles)

src/FastBertTokenizer/AsyncBatchEnumerator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,15 @@ public bool MoveNext()
147147
var vt = parent.MoveNextAsync();
148148
if (!vt.IsCompletedSuccessfully)
149149
{
150-
throw new NotImplementedException();
150+
throw new NotImplementedException(); // UnreachableException is only available in .NET 7+
151151
}
152152
else
153153
{
154154
return vt.Result;
155155
}
156156
}
157157

158-
public void Reset() => throw new InvalidOperationException("Multiple enumeration is not supported.");
158+
public void Reset() => throw new NotSupportedException("Multiple enumeration is not supported.");
159159
}
160160
}
161161

0 commit comments

Comments
 (0)