diff --git a/llama_for_kobold.py b/llama_for_kobold.py index b895c1ea19818..6e341bee34a92 100644 --- a/llama_for_kobold.py +++ b/llama_for_kobold.py @@ -37,12 +37,12 @@ class generation_outputs(ctypes.Structure): handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever handle.generate.restype = generation_outputs -def load_model(model_filename,batch_size=8,max_context_length=512,threads=4,n_parts_overwrite=-1): +def load_model(model_filename,batch_size=8,max_context_length=2048,n_parts_overwrite=-1): inputs = load_model_inputs() inputs.model_filename = model_filename.encode("UTF-8") inputs.batch_size = batch_size inputs.max_context_length = max_context_length - inputs.threads = threads + inputs.threads = os.cpu_count() inputs.n_parts_overwrite = n_parts_overwrite ret = handle.load_model(inputs) return ret @@ -74,7 +74,7 @@ def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1 # global vars global friendlymodelname friendlymodelname = "" -maxctx = 512 +maxctx = 2048 maxlen = 128 modelbusy = False port = 5001 @@ -265,7 +265,7 @@ def stop(self): mdl_nparts += 1 modelname = os.path.abspath(sys.argv[1]) print("Loading model: " + modelname) - loadok = load_model(modelname,24,maxctx,4,mdl_nparts) + loadok = load_model(modelname,24,maxctx,mdl_nparts) print("Load Model OK: " + str(loadok)) #friendlymodelname = Path(modelname).stem ### this wont work on local kobold api, so we must hardcode a known HF model name