Skip to content

Commit a0a08d7

Browse files
The gemma.cpp Authors, copybara-github
The gemma.cpp Authors
authored and committed
Adds:
- GemmaContext class that exposes Gemma functionality
- C API that uses GemmaContext
- C# interop class in GemmaInterop.cs
- New END_OF_TURN_ID in tokenizer.h, useful when dealing with instruction-tuned prompts

PiperOrigin-RevId: 730754638
1 parent b3b4b9f commit a0a08d7

File tree

7 files changed

+552
-0
lines changed

7 files changed

+552
-0
lines changed

CMakeLists.txt

+38
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ set(SOURCES
110110
util/threading.h
111111
)
112112

113+
# The GemmaContext class and the C API wrapping it are compiled only when the
# DLL build is requested via BUILD_GEMMA_DLL.
if(BUILD_GEMMA_DLL)
  message(STATUS "Including C API files for DLL build")
  list(APPEND SOURCES gemma/context.h gemma/context.cc)
  list(APPEND SOURCES gemma/c_api.h gemma/c_api.cc)
endif()
123+
113124
if(NOT CMAKE_BUILD_TYPE)
114125
set(CMAKE_BUILD_TYPE "Release")
115126
endif()
@@ -129,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
129140
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
130141
install(TARGETS libgemma DESTINATION lib)
131142

143+
# Shared library target exposing the C API (gemma/c_api.h) for C# interop.
# Only configured when BUILD_GEMMA_DLL is set.
if(BUILD_GEMMA_DLL)
  add_library(gemma_shared SHARED ${SOURCES})
  set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
  # Empty prefix plus OUTPUT_NAME yields an artifact named "gemma" (no "lib"
  # prefix), which is the library name GemmaInterop.cs passes to DllImport.
  set_target_properties(gemma_shared PROPERTIES
    PREFIX ""
    OUTPUT_NAME "gemma"
  )
  set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
  target_include_directories(gemma_shared PUBLIC ./)
  # Pull the dependencies in as whole archives.
  # NOTE: the $<LINK_LIBRARY:...> generator expression needs CMake >= 3.24.
  target_link_libraries(gemma_shared PRIVATE
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
  )
  target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
  # GEMMA_EXPORTS makes GEMMA_API in gemma/c_api.h expand to dllexport on
  # Windows.
  target_compile_definitions(gemma_shared
    PRIVATE
      GEMMA_EXPORTS
      $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
  )
  target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
  install(TARGETS gemma_shared DESTINATION lib)
  install(FILES gemma/c_api.h DESTINATION include/gemma)
  # GemmaInterop.cs is added at the repository root by this change, not under
  # gemma/, so install it from there.
  install(FILES GemmaInterop.cs DESTINATION include/gemma)
endif()
169+
132170
# Executable Target
133171

134172
add_executable(gemma gemma/run.cc)

GemmaInterop.cs

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
using System;
2+
using System.Diagnostics;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
namespace GemmaCpp
6+
{
7+
// Exception thrown by the Gemma wrapper when a native call reports failure.
public class GemmaException : Exception
{
    // Creates an exception with no message.
    public GemmaException() { }

    // Creates an exception carrying a description of the failure.
    public GemmaException(string message) : base(message) { }

    // Creates an exception that wraps the underlying cause.
    public GemmaException(string message, Exception innerException)
        : base(message, innerException) { }
}
11+
12+
// Managed wrapper around the native Gemma C API (GemmaCreate, GemmaGenerate,
// GemmaCountTokens, ...). Each instance owns one native context handle and
// releases it in Dispose() or, as a fallback, in the finalizer.
public class Gemma : IDisposable
{
    private IntPtr _context;
    private bool _disposed;

    // Guards the one-time, lazy preload of the native library.
    private static readonly object s_loadLock = new object();
    private static bool s_nativeLoaded;

    // Optional: path of the native library to preload on Windows. It is read
    // when the first Gemma instance is constructed, so set it before that;
    // later changes have no effect.
    public static string DllPath { get; set; } = "gemma.dll";

    [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
    private static extern IntPtr LoadLibrary(string lpFileName);

    // Preloads the native library from DllPath exactly once.
    // This is done lazily rather than in a static constructor: a static
    // constructor runs before the first access to any static member, i.e.
    // before a caller could ever assign DllPath.
    // Throws DllNotFoundException if the library cannot be loaded.
    private static void EnsureNativeLibraryLoaded()
    {
        lock (s_loadLock)
        {
            if (s_nativeLoaded)
                return;

            // kernel32.dll only exists on Windows; on other platforms the
            // runtime's normal DllImport("gemma") probing is relied upon.
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
            {
                string path = DllPath;
                if (LoadLibrary(path) == IntPtr.Zero)
                {
                    // Read the error code right after the failing call.
                    int error = Marshal.GetLastWin32Error();
                    throw new DllNotFoundException($"Failed to load {path}. Error: {error}");
                }
            }

            s_nativeLoaded = true;
        }
    }

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern IntPtr GemmaCreate(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
        int maxLength);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaDestroy(IntPtr context);

    // Delegate type for token callbacks
    public delegate bool TokenCallback(string token);

    // Native token callback. The C declaration returns `bool`, which is one
    // byte, whereas the default interop marshalling of bool is a 4-byte BOOL;
    // hence the explicit I1 on the return value.
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    private delegate bool GemmaTokenCallback(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern int GemmaGenerate(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
        int maxLength,
        GemmaTokenCallback callback,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern int GemmaCountTokens(
        IntPtr context,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string text);

    // Native callback delegate type
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    private delegate void GemmaLogCallback(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string message,
        IntPtr userData);

    [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
    private static extern void GemmaSetLogCallback(
        IntPtr context,
        GemmaLogCallback callback,
        IntPtr userData);

    // Keeps the log callback delegate alive while the native side may still
    // hold it; only used by the optional logging setup below.
    private GCHandle _logCallbackHandle;

    // Creates the native context from the given tokenizer/weights files.
    // Throws DllNotFoundException if the native library cannot be preloaded
    // and GemmaException if the native side fails to create the context.
    public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
    {
        EnsureNativeLibraryLoaded();

        _context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
        if (_context == IntPtr.Zero)
        {
            // Nothing to release, so there is no need to run the finalizer.
            GC.SuppressFinalize(this);
            throw new GemmaException("Failed to create Gemma context");
        }

        // optionally: set up logging
        /*
        GemmaLogCallback logCallback = (message, _) =>
        {
#if UNITY_ENGINE
            Debug.Log($"Gemma: {message}");
#else
            Debug.WriteLine($"Gemma: {message}");
#endif
        };
        _logCallbackHandle = GCHandle.Alloc(logCallback);
        GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
        */
    }

    // Returns the value reported by the native GemmaCountTokens for `prompt`.
    // Throws ObjectDisposedException after Dispose() and GemmaException if the
    // context handle is invalid.
    public int CountTokens(string prompt)
    {
        ThrowIfUnusable();
        return GemmaCountTokens(_context, prompt);
    }

    // Generates a response for `prompt` without per-token notifications.
    public string Generate(string prompt, int maxLength = 4096)
    {
        return Generate(prompt, null, maxLength);
    }

    // Generates a response for `prompt`. If `callback` is non-null it is
    // forwarded to the native side as the token callback. `maxLength` sizes
    // the output buffer handed to the native call.
    // Throws ArgumentNullException for a null prompt, ObjectDisposedException
    // after Dispose(), and GemmaException if the context is invalid or the
    // native call reports failure (negative return value).
    public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
    {
        ThrowIfUnusable();

        if (prompt == null)
            throw new ArgumentNullException(nameof(prompt));

        var output = new StringBuilder(maxLength);

        // The delegate lives in a local rather than an instance field, so
        // concurrent or re-entrant Generate calls cannot clobber each other.
        GemmaTokenCallback nativeCallback = null;
        if (callback != null)
        {
            nativeCallback = (text, _) => callback(text);
        }

        int length = GemmaGenerate(_context, prompt, output, maxLength,
                                   nativeCallback, IntPtr.Zero);

        // As before, the delegate is only kept alive for the duration of the
        // native call.
        GC.KeepAlive(nativeCallback);

        if (length < 0)
            throw new GemmaException("Generation failed");

        return output.ToString();
    }

    // Releases the native context. Safe to call more than once.
    public void Dispose()
    {
        ReleaseResources();
        // Everything is released; the finalizer has nothing left to do.
        GC.SuppressFinalize(this);
    }

    // Fallback cleanup if Dispose() was never called.
    ~Gemma()
    {
        ReleaseResources();
    }

    // Shared precondition check for the public operations.
    private void ThrowIfUnusable()
    {
        if (_disposed)
            throw new ObjectDisposedException(nameof(Gemma));

        if (_context == IntPtr.Zero)
            throw new GemmaException("Gemma context is invalid");
    }

    // Destroys the native context and frees the log-callback handle, once.
    private void ReleaseResources()
    {
        if (_disposed)
            return;

        if (_context != IntPtr.Zero)
        {
            GemmaDestroy(_context);
            _context = IntPtr.Zero;
        }
        if (_logCallbackHandle.IsAllocated)
            _logCallbackHandle.Free();
        _disposed = true;
    }
}
175+
}

gemma/c_api.cc

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#ifndef GEMMA_EXPORTS
2+
#define GEMMA_EXPORTS
3+
#endif
4+
5+
#include "gemma/c_api.h"
6+
7+
#include "util/app.h"
8+
9+
extern "C" {
10+
11+
// Creates a GemmaContext from the given tokenizer/model/weights description.
// Returns a pointer the caller must release with GemmaDestroy(), or nullptr if
// any string argument is null or if construction throws.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length) {
  // Reject null strings up front instead of forwarding them to the
  // constructor; nullptr is already this function's failure value.
  if (!tokenizer_path || !model_type || !weights_path || !weight_type) {
    return nullptr;
  }

  try {
    // Kludge: there is no command line to parse here, so AppArgs is
    // initialized to its defaults and the fields below are set by hand.
    gcpp::AppArgs app_args;
    app_args.Init();
    app_args.max_packages = 1;
    app_args.verbosity = 0;
    app_args.spin = gcpp::Tristate::kFalse;

    return new GemmaContext(tokenizer_path, model_type, weights_path,
                            weight_type, app_args, max_length);
  } catch (...) {
    // Exceptions must not cross the C boundary; report failure as nullptr.
    return nullptr;
  }
}
29+
30+
// Destroys a context previously returned by GemmaCreate(). Passing nullptr is
// a no-op, since deleting a null pointer does nothing.
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
  gcpp::GemmaContext* const context = static_cast<gcpp::GemmaContext*>(ctx);
  delete context;
}
33+
34+
// Forwards to GemmaContext::Generate() and returns its result.
// Returns -1 if `ctx` or `prompt` is null, or if generation throws. The
// remaining arguments are passed through unchanged.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data) {
  if (!ctx || !prompt) return -1;
  try {
    return static_cast<gcpp::GemmaContext*>(ctx)->Generate(
        prompt, output, max_length, callback, user_data);
  } catch (...) {
    // Like GemmaCreate(), never let an exception escape through the C ABI.
    return -1;
  }
}
41+
42+
// Forwards to GemmaContext::CountTokens() and returns its result.
// Returns -1 if `ctx` or `text` is null, or if counting throws.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
  if (!ctx || !text) return -1;
  try {
    return static_cast<gcpp::GemmaContext*>(ctx)->CountTokens(text);
  } catch (...) {
    // Like GemmaCreate(), never let an exception escape through the C ABI.
    return -1;
  }
}
46+
47+
// Hands `callback` and `user_data` to the context's SetLogCallback().
// Does nothing when `ctx` is null.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data) {
  if (ctx != nullptr) {
    gcpp::GemmaContext* const context = static_cast<gcpp::GemmaContext*>(ctx);
    context->SetLogCallback(callback, user_data);
  }
}
52+
}

gemma/c_api.h

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2024 Google LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// https://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#ifndef THIRD_PARTY_GEMMA_C_API_H_
#define THIRD_PARTY_GEMMA_C_API_H_

// C API for Gemma. Usable from both C and C++ translation units.
//
// gemma/context.h declares gcpp::GemmaContext and is C++ only, so it is
// included only when compiling as C++. C consumers see GemmaContext as an
// opaque struct and need <stdbool.h> for the `bool` callback return type.
#ifdef __cplusplus
#include "gemma/context.h"
#else
#include <stdbool.h>
#endif

// GEMMA_API marks the exported entry points: dllexport when building the
// library on Windows (GEMMA_EXPORTS defined), dllimport when consuming it,
// and default symbol visibility elsewhere.
#ifdef _WIN32
#ifdef GEMMA_EXPORTS
#define GEMMA_API __declspec(dllexport)
#else
#define GEMMA_API __declspec(dllimport)
#endif
#else
#define GEMMA_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

// Context handle: the real class in C++, an opaque type in C.
#ifdef __cplusplus
typedef gcpp::GemmaContext GemmaContext;
#else
typedef struct GemmaContext GemmaContext;
#endif

// Callback receiving generated text; `user_data` is the pointer passed to
// GemmaGenerate().
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Callback receiving log messages; `user_data` is the pointer passed to
// GemmaSetLogCallback().
typedef void (*GemmaLogCallback)(const char* message, void* user_data);

// Creates a context; returns NULL on failure. Release with GemmaDestroy().
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length);
// Destroys a context created by GemmaCreate().
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
// Runs generation for `prompt`; returns -1 if `ctx` is NULL.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data);

// Counts the tokens in `text`; returns -1 if `ctx` or `text` is NULL.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);

// Installs a log callback on the context; no-op if `ctx` is NULL.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data);

#ifdef __cplusplus
}
#endif

#endif  // THIRD_PARTY_GEMMA_C_API_H_

0 commit comments

Comments
 (0)