1
+ using System ;
2
+ using System . Diagnostics ;
3
+ using System . Runtime . InteropServices ;
4
+ using System . Text ;
5
+ namespace GemmaCpp
6
+ {
7
+ public class GemmaException : Exception
8
+ {
9
+ public GemmaException ( string message ) : base ( message ) { }
10
+ }
11
+
12
+ public class Gemma : IDisposable
13
+ {
14
+ private IntPtr _context ;
15
+ private bool _disposed ;
16
+
17
+ // Optional: Allow setting DLL path
18
+ public static string DllPath { get ; set ; } = "gemma.dll" ;
19
+
20
+ [ DllImport ( "kernel32.dll" , CharSet = CharSet . Unicode , SetLastError = true ) ]
21
+ private static extern IntPtr LoadLibrary ( string lpFileName ) ;
22
+
23
+ static Gemma ( )
24
+ {
25
+ // Load DLL from specified path
26
+ if ( LoadLibrary ( DllPath ) == IntPtr . Zero )
27
+ {
28
+ throw new DllNotFoundException ( $ "Failed to load { DllPath } . Error: { Marshal . GetLastWin32Error ( ) } ") ;
29
+ }
30
+ }
31
+
32
+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
33
+ private static extern IntPtr GemmaCreate (
34
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string tokenizerPath ,
35
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string modelType ,
36
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string weightsPath ,
37
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string weightType ,
38
+ int maxLength ) ;
39
+
40
+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
41
+ private static extern void GemmaDestroy ( IntPtr context ) ;
42
+
43
+ // Delegate type for token callbacks
44
+ public delegate bool TokenCallback ( string token ) ;
45
+
46
+ // Keep delegate alive for duration of calls
47
+ private GCHandle _callbackHandle ;
48
+
49
+ [ UnmanagedFunctionPointer ( CallingConvention . Cdecl ) ]
50
+ private delegate bool GemmaTokenCallback (
51
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string text ,
52
+ IntPtr userData ) ;
53
+
54
+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
55
+ private static extern int GemmaGenerate (
56
+ IntPtr context ,
57
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string prompt ,
58
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] StringBuilder output ,
59
+ int maxLength ,
60
+ GemmaTokenCallback callback ,
61
+ IntPtr userData ) ;
62
+
63
+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
64
+ private static extern int GemmaCountTokens (
65
+ IntPtr context ,
66
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string text ) ;
67
+
68
+ // Native callback delegate type
69
+ [ UnmanagedFunctionPointer ( CallingConvention . Cdecl ) ]
70
+ private delegate void GemmaLogCallback (
71
+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string message ,
72
+ IntPtr userData ) ;
73
+
74
+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
75
+ private static extern void GemmaSetLogCallback (
76
+ IntPtr context ,
77
+ GemmaLogCallback callback ,
78
+ IntPtr userData ) ;
79
+
80
+ private GCHandle _logCallbackHandle ;
81
+
82
+ public Gemma ( string tokenizerPath , string modelType , string weightsPath , string weightType , int maxLength = 8192 )
83
+ {
84
+ _context = GemmaCreate ( tokenizerPath , modelType , weightsPath , weightType , maxLength ) ;
85
+ if ( _context == IntPtr . Zero )
86
+ {
87
+ throw new GemmaException ( "Failed to create Gemma context" ) ;
88
+ }
89
+
90
+ // optionally: set up logging
91
+ /*
92
+ GemmaLogCallback logCallback = (message, _) =>
93
+ {
94
+ #if UNITY_ENGINE
95
+ Debug.Log($"Gemma: {message}");
96
+ #else
97
+ Debug.WriteLine($"Gemma: {message}");
98
+ #endif
99
+ };
100
+ _logCallbackHandle = GCHandle.Alloc(logCallback);
101
+ GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
102
+ */
103
+ }
104
+
105
+ public int CountTokens ( string prompt )
106
+ {
107
+ if ( _disposed )
108
+ throw new ObjectDisposedException ( nameof ( Gemma ) ) ;
109
+
110
+ if ( _context == IntPtr . Zero )
111
+ throw new GemmaException ( "Gemma context is invalid" ) ;
112
+ int count = GemmaCountTokens ( _context , prompt ) ;
113
+ return count ;
114
+ }
115
+
116
+ public string Generate ( string prompt , int maxLength = 4096 )
117
+ {
118
+ return Generate ( prompt , null , maxLength ) ;
119
+ }
120
+
121
+ public string Generate ( string prompt , TokenCallback callback , int maxLength = 4096 )
122
+ {
123
+ if ( _disposed )
124
+ throw new ObjectDisposedException ( nameof ( Gemma ) ) ;
125
+
126
+ if ( _context == IntPtr . Zero )
127
+ throw new GemmaException ( "Gemma context is invalid" ) ;
128
+
129
+ var output = new StringBuilder ( maxLength ) ;
130
+ GemmaTokenCallback nativeCallback = null ;
131
+
132
+ if ( callback != null )
133
+ {
134
+ nativeCallback = ( text , _ ) => callback ( text ) ;
135
+ _callbackHandle = GCHandle . Alloc ( nativeCallback ) ;
136
+ }
137
+
138
+ try
139
+ {
140
+ int length = GemmaGenerate ( _context , prompt , output , maxLength ,
141
+ nativeCallback , IntPtr . Zero ) ;
142
+
143
+ if ( length < 0 )
144
+ throw new GemmaException ( "Generation failed" ) ;
145
+
146
+ return output . ToString ( ) ;
147
+ }
148
+ finally
149
+ {
150
+ if ( _callbackHandle . IsAllocated )
151
+ _callbackHandle . Free ( ) ;
152
+ }
153
+ }
154
+
155
+ public void Dispose ( )
156
+ {
157
+ if ( ! _disposed )
158
+ {
159
+ if ( _context != IntPtr . Zero )
160
+ {
161
+ GemmaDestroy ( _context ) ;
162
+ _context = IntPtr . Zero ;
163
+ }
164
+ if ( _logCallbackHandle . IsAllocated )
165
+ _logCallbackHandle . Free ( ) ;
166
+ _disposed = true ;
167
+ }
168
+ }
169
+
170
+ ~ Gemma ( )
171
+ {
172
+ Dispose ( ) ;
173
+ }
174
+ }
175
+ }
0 commit comments