Minimalist example #1840

Merged: 7 commits, Jun 16, 2023
8 changes: 7 additions & 1 deletion Makefile
@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple
 
 ifdef LLAMA_BUILD_SERVER
 BUILD_TARGETS += server
@@ -276,6 +276,12 @@ main: examples/main/main.cpp build-info.h ggml.
 	@echo '==== Run ./main -h for help. ===='
 	@echo
 
+simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+	@echo
+	@echo '==== Run ./simple -h for help. ===='
+	@echo
+
 quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
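With this rule added to BUILD_TARGETS, building the example should just be a matter of running "make simple" (or going through CMake with the CMakeLists.txt below). Per the usage message in simple.cpp, the resulting binary is invoked as "./simple MODEL_PATH [PROMPT]", with "Hello my name is" used as the fallback prompt when none is given.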
7 changes: 7 additions & 0 deletions examples/simple/CMakeLists.txt
@@ -0,0 +1,7 @@
set(TARGET simple)
add_executable(${TARGET} simple.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
  add_dependencies(${TARGET} BUILD_INFO)
endif()
177 changes: 177 additions & 0 deletions examples/simple/simple.cpp
@@ -0,0 +1,177 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif



int main(int argc, char ** argv)
{
    gpt_params params;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc == 1 || argv[1][0] == '-' )
    {
        printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    if ( argc >= 2 )
    {
        params.model = argv[1];
    }

    if ( argc >= 3 )
    {
        params.prompt = argv[2];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "Hello my name is";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------

    llama_init_backend();

    llama_context * ctx ;

    ctx = llama_init_from_gpt_params( params );

    if ( ctx == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------

    std::vector<llama_token> tokens_list;
    tokens_list = ::llama_tokenize( ctx , params.prompt , true );

    const int max_context_size     = llama_n_ctx( ctx );
    const int max_tokens_list_size = max_context_size - 4 ;

    if ( (int)tokens_list.size() > max_tokens_list_size )
    {
        fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" ,
                 __func__ , (int)tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for( auto id : tokens_list )
    {
        printf( "%s" , llama_token_to_str( ctx , id ) );
    }

    fflush(stdout);


    //---------------------------------
    // Main prediction loop :
    //---------------------------------

    // The LLM keeps a contextual cache memory of previous token evaluation.
    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
    // example, we will just stop the loop once this cache is full or once an end of stream is detected.

    while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
    {
        //---------------------------------
        // Evaluate the tokens :
        //---------------------------------

        if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
        {
            fprintf( stderr, "%s : failed to eval\n" , __func__ );
            return 1;
        }

        tokens_list.clear();

        //---------------------------------
        // Select the best prediction :
        //---------------------------------

        llama_token new_token_id = 0;

        auto logits  = llama_get_logits( ctx );
        auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)

        std::vector<llama_token_data> candidates;
        candidates.reserve( n_vocab );

        for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
        {
            candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
        }

        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // Select it using the "Greedy sampling" method :
        new_token_id = llama_sample_token_greedy( ctx , &candidates_p );


        // is it an end of stream ?
        if ( new_token_id == llama_token_eos() )
        {
            fprintf(stderr, " [end of text]\n");
            break;
        }

        // Print the new token :
        printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
        fflush( stdout );

        // Push this new token for next evaluation :
        tokens_list.push_back( new_token_id );

    } // wend of main loop

    llama_free( ctx );

    return 0;
}

// EOF
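
The comment above the main prediction loop refers to the "context swapping" trick used by the main example once the KV cache fills up, which this minimalist example deliberately skips. As a rough, hypothetical sketch of that idea (simplified compared to examples/main, and assuming that re-running llama_eval() from position 0 rebuilds the usable cache from the tokens passed in), it might look like this:

#include "llama.h"

#include <vector>

// Hypothetical helper, not part of this PR: when the cache is about to fill,
// keep only the most recent n_keep tokens (oldest first in last_tokens) and
// re-evaluate them from position 0 so generation can continue instead of stopping.
static bool rough_context_swap( llama_context * ctx ,
                                std::vector<llama_token> & last_tokens ,
                                int n_keep ,
                                int n_threads )
{
    if ( (int) last_tokens.size() > n_keep )
    {
        last_tokens.erase( last_tokens.begin() , last_tokens.end() - n_keep );
    }

    // llama_eval returns 0 on success, same convention as in the loop above
    return llama_eval( ctx , last_tokens.data() , (int) last_tokens.size() , 0 , n_threads ) == 0;
}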
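
For readers new to the sampling API: greedy sampling simply picks the token with the largest logit, which is what llama_sample_token_greedy() does for the candidate array built in the loop. A stand-alone sketch of that selection rule in plain C++ (an illustration only, not a replacement for the library call):

#include <algorithm>

// Illustration of the greedy rule: the chosen token id is the index of the
// largest logit in the vocabulary.
static int greedy_token( const float * logits , int n_vocab )
{
    // std::max_element returns an iterator to the largest value;
    // its offset from the start of the array is the token id.
    return (int) ( std::max_element( logits , logits + n_vocab ) - logits );
}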