diff --git a/Pipfile b/Pipfile index e62f30ccc7..e8f2a4fd7e 100644 --- a/Pipfile +++ b/Pipfile @@ -50,6 +50,7 @@ tiktoken = "==0.4.0" hf-transfer = "==0.1.3" peft = "==0.5.0" azure-storage-file-datalake = ">=12.12.0" +keyring = "==24.2.0" [dev-packages] black = "==23.7.0" diff --git a/Pipfile.lock b/Pipfile.lock index e2c4564ac5..91f450303d 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "20732924c37ba84a1dd1804d15edc2e3681c7168e1709738c906bca9618177cb" + "sha256": "a875a144c98f35ebb3b7ed678f401997eddf69905b933f57a49e17d099f3af22" }, "pipfile-spec": 6, "requires": { @@ -228,19 +228,19 @@ }, "boto3": { "hashes": [ - "sha256:4cd3e96900fb50bddc9f48007176c80d15396d08c5248b25a41220f3570e014f", - "sha256:c0211a3e830432851c73fa1e136b14dbb6d02b5c9a5e1272c557e63538620b88" + "sha256:4ee914266c9bed16978677a367fd05053d8dcaddcbe998c9df30787ab73f87aa", + "sha256:682abbd304e93e726163d7de7448c1bf88108c72cf6a23dceb6bba86fdc86dff" ], "index": "pypi", - "version": "==1.28.43" + "version": "==1.28.45" }, "botocore": { "hashes": [ - "sha256:b4a3a1fcf75011351e2b0d3eb991f51f8d44a375d3e065f907dac67db232fc97", - "sha256:d8b0c41c8c75d82f15fee57f7d54a852a99810faacbeb9d6f3f022558a2c330e" + "sha256:85ff64a0ac2705c4ba36268c3b2dbc1184062e9cf729a89dd66c2f54f730fc79", + "sha256:cceb150cff1d7f7a6faf655510a8384eb4505a33b430495fe1744d03a70dc66a" ], "markers": "python_version >= '3.7'", - "version": "==1.31.43" + "version": "==1.31.45" }, "bravado": { "hashes": [ @@ -788,11 +788,11 @@ }, "gitpython": { "hashes": [ - "sha256:9cbefbd1789a5fe9bcf621bb34d3f441f3a90c8461d377f84eda73e721d9b06b", - "sha256:c19b4292d7a1d3c0f653858db273ff8a6614100d1eb1528b014ec97286193c09" + "sha256:4bb0c2a6995e85064140d31a33289aa5dce80133a23d36fcd372d716c54d3ebf", + "sha256:8d22b5cfefd17c79914226982bb7851d6ade47545b1735a9d010a2a4c26d8388" ], "markers": "python_version >= '3.7'", - "version": "==3.1.35" + "version": "==3.1.36" }, "gputil": { "hashes": [ @@ -918,19 +918,19 @@ }, 
"httpcore": { "hashes": [ - "sha256:a6f30213335e34c1ade7be6ec7c47f19f50c56db36abef1a9dfa3815b1cb3888", - "sha256:c2789b767ddddfa2a5782e3199b2b7f6894540b17b16ec26b2c4d8e103510b87" + "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9", + "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced" ], - "markers": "python_version >= '3.7'", - "version": "==0.17.3" + "markers": "python_version >= '3.8'", + "version": "==0.18.0" }, "httpx": { "hashes": [ - "sha256:06781eb9ac53cde990577af654bd990a4949de37a28bdb4a230d434f3a30b9bd", - "sha256:5853a43053df830c20f8110c5e69fe44d035d850b2dfe795e196f00fdb774bdd" + "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100", + "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875" ], - "markers": "python_version >= '3.7'", - "version": "==0.24.1" + "markers": "python_version >= '3.8'", + "version": "==0.25.0" }, "huggingface-hub": { "hashes": [ @@ -948,6 +948,14 @@ "markers": "python_version >= '3.5'", "version": "==3.4" }, + "importlib-metadata": { + "hashes": [ + "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb", + "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743" + ], + "markers": "python_version < '3.12'", + "version": "==6.8.0" + }, "inquirer": { "hashes": [ "sha256:a7441fd74d06fcac4385218a1f5e8703f7a113f7944e01af47b8c58e84f95ce5", @@ -970,6 +978,22 @@ ], "version": "==20.11.0" }, + "jaraco.classes": { + "hashes": [ + "sha256:10afa92b6743f25c0cf5f37c6bb6e18e2c5bb84a16527ccfc0040ea377e7aaeb", + "sha256:c063dd08e89217cee02c8d5e5ec560f2c8ce6cdc2fcdc2e68f7b2e5547ed3621" + ], + "markers": "python_version >= '3.8'", + "version": "==3.3.0" + }, + "jeepney": { + "hashes": [ + "sha256:5efe48d255973902f6badc3ce55e2aa6c5c3b3bc642059ef3a91247bcfcc5806", + "sha256:c0a454ad016ca575060802ee4d590dd912e35c122fa04e70306de3d076cce755" + ], + "markers": "sys_platform == 'linux'", + "version": "==0.8.0" + }, "jinja2": { "hashes": 
[ "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852", @@ -1032,6 +1056,14 @@ "index": "pypi", "version": "==1.5.16" }, + "keyring": { + "hashes": [ + "sha256:4901caaf597bfd3bbd78c9a0c7c4c29fcd8310dab2cffefe749e916b6527acd6", + "sha256:ca0746a19ec421219f4d713f848fa297a661a8a8c1504867e55bfb5e09091509" + ], + "index": "pypi", + "version": "==24.2.0" + }, "lit": { "hashes": [ "sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a" @@ -1111,6 +1143,14 @@ ], "version": "==1.6" }, + "more-itertools": { + "hashes": [ + "sha256:626c369fa0eb37bac0291bce8259b332fd59ac792fa5497b59837309cd5b114a", + "sha256:64e0735fcfdc6f3464ea133afe8ea4483b1c5fe3a3d69852e6503b43a0b222e6" + ], + "markers": "python_version >= '3.8'", + "version": "==10.1.0" + }, "mpmath": { "hashes": [ "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", @@ -2046,6 +2086,14 @@ "markers": "python_version < '3.13' and python_version >= '3.9'", "version": "==1.11.2" }, + "secretstorage": { + "hashes": [ + "sha256:2403533ef369eca6d2ba81718576c5e0f564d5cca1b58f73a8b23e7d4eeebd77", + "sha256:f356e6628222568e3af06f2eba8df495efa13b3b63081dafd4f7d9a7b7bc9f99" + ], + "markers": "sys_platform == 'linux'", + "version": "==3.3.3" + }, "sentencepiece": { "hashes": [ "sha256:004e6a621d4bc88978eecb6ea7959264239a17b70f2cbc348033d8195c9808ec", @@ -2099,11 +2147,11 @@ }, "setuptools": { "hashes": [ - "sha256:00478ca80aeebeecb2f288d3206b0de568df5cd2b8fada1209843cc9a8d88a48", - "sha256:af3d5949030c3f493f550876b2fd1dd5ec66689c4ee5d5344f009746f71fd5a8" + "sha256:56ee14884fd8d0cd015411f4a13f40b4356775a0aefd9ebc1d3bfb9a1acb32f1", + "sha256:eff96148eb336377ab11beee0c73ed84f1709a40c0b870298b0d058828761bae" ], "markers": "python_version >= '3.8'", - "version": "==68.2.0" + "version": "==68.2.1" }, "simplejson": { "hashes": [ @@ -2561,11 +2609,11 @@ }, "websocket-client": { "hashes": [ - "sha256:53e95c826bf800c4c465f50093a8c4ff091c7327023b10bfaff40cf1ef170eaa", - 
"sha256:ce54f419dfae71f4bdba69ebe65bf7f0a93fe71bc009ad3a010aacc3eebad537" + "sha256:3aad25d31284266bcfcfd1fd8a743f63282305a364b8d0948a43bd606acc652f", + "sha256:6cfc30d051ebabb73a5fa246efdcc14c8fbebbd0330f8984ac3bb6d9edd2ad03" ], "markers": "python_version >= '3.8'", - "version": "==1.6.2" + "version": "==1.6.3" }, "xxhash": { "hashes": [ @@ -2752,6 +2800,14 @@ ], "markers": "python_version >= '3.7'", "version": "==1.9.2" + }, + "zipp": { + "hashes": [ + "sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0", + "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147" + ], + "markers": "python_version >= '3.8'", + "version": "==3.16.2" } }, "develop": { @@ -3100,11 +3156,11 @@ }, "pytest-html": { "hashes": [ - "sha256:3b473cc278272f8b5a34cd3bf10f88ac5fcb17cb5af22f9323514af00c310e64", - "sha256:79c4677ed6196417bf290d8b81f706342ae49f726f623728efa3f7dfff09f8eb" + "sha256:2d5e2863196a940607a477d747793c75f18d76fa72276c5d4db423a39665f55c", + "sha256:81383115dec5d182bf19c5dbae0bc2bebf7ad5c371fda477ebe9b7b054d2b63b" ], "index": "pypi", - "version": "==4.0.0" + "version": "==4.0.1" }, "pytest-metadata": { "hashes": [ diff --git a/README.md b/README.md index 1d955f0b09..d8f5c7deba 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Using CLI for fine-tuning LLMs: ## What's New +- [PR 364](https://github.com/h2oai/h2o-llmstudio/pull/364) User secrets are now handled more securely and flexible. Support for handling secrets using the 'keyring' library was added. User settings are tried to be migrated automatically. - [PR 328](https://github.com/h2oai/h2o-llmstudio/pull/328) RLHF is now a separate problem type. Note that starting a new RLHF experiment from an old experiment that used RLHF is no longer supported. To continue from a previous experiment, please start a new experiment and enter the settings from the previous experiment manually. 
- [PR 308](https://github.com/h2oai/h2o-llmstudio/pull/308) Sequence to sequence models have been added as a new problem type. - [PR 152](https://github.com/h2oai/h2o-llmstudio/pull/152) Add RLHF functionality for fine-tuning LLMs. diff --git a/llm_studio/app_utils/config.py b/llm_studio/app_utils/config.py index 64f16111e6..bb00ed725f 100644 --- a/llm_studio/app_utils/config.py +++ b/llm_studio/app_utils/config.py @@ -85,6 +85,7 @@ def get_size(x): ], "user_settings": { "theme_dark": True, + "credential_saver": ".env File", "default_aws_bucket_name": f"{os.getenv('AWS_BUCKET', 'bucket_name')}", "default_aws_access_key": os.getenv("AWS_ACCESS_KEY_ID", ""), "default_aws_secret_key": os.getenv("AWS_SECRET_ACCESS_KEY", ""), diff --git a/llm_studio/app_utils/handlers.py b/llm_studio/app_utils/handlers.py index 13fe0fdc7d..555690899b 100644 --- a/llm_studio/app_utils/handlers.py +++ b/llm_studio/app_utils/handlers.py @@ -42,11 +42,12 @@ list_current_experiments, ) from llm_studio.app_utils.sections.settings import settings -from llm_studio.app_utils.utils import ( - add_model_type, - load_user_settings, - save_user_settings, +from llm_studio.app_utils.setting_utils import ( + load_default_user_settings, + load_user_settings_and_secrets, + save_user_settings_and_secrets, ) +from llm_studio.app_utils.utils import add_model_type from llm_studio.app_utils.wave_utils import report_error, wave_utils_handle_error logger = logging.getLogger(__name__) @@ -77,13 +78,13 @@ async def handle(q: Q) -> None: await settings(q) elif q.args["save_settings"]: logger.info("Saving user settings") - save_user_settings(q) + await save_user_settings_and_secrets(q) await settings(q) elif q.args["load_settings"]: - load_user_settings(q) + load_user_settings_and_secrets(q) await settings(q) elif q.args["restore_default_settings"]: - load_user_settings(q, force_defaults=True) + load_default_user_settings(q) await settings(q) elif q.args["report_error"]: diff --git 
a/llm_studio/app_utils/initializers.py b/llm_studio/app_utils/initializers.py index 742c0cbe64..f86bcc1b5a 100644 --- a/llm_studio/app_utils/initializers.py +++ b/llm_studio/app_utils/initializers.py @@ -6,21 +6,20 @@ from bokeh.resources import Resources as BokehResources from h2o_wave import Q +from llm_studio.app_utils.config import default_cfg +from llm_studio.app_utils.db import Database, Dataset from llm_studio.app_utils.sections.common import interface -from llm_studio.src.utils.config_utils import load_config_py, save_config_yaml - -from .config import default_cfg -from .db import Database, Dataset -from .utils import ( +from llm_studio.app_utils.setting_utils import load_user_settings_and_secrets +from llm_studio.app_utils.utils import ( get_data_dir, get_database_dir, get_download_dir, get_output_dir, get_user_db_path, get_user_name, - load_user_settings, prepare_default_dataset, ) +from llm_studio.src.utils.config_utils import load_config_py, save_config_yaml logger = logging.getLogger(__name__) @@ -97,7 +96,7 @@ async def initialize_client(q: Q) -> None: import_data(q) - load_user_settings(q) + load_user_settings_and_secrets(q) await interface(q) diff --git a/llm_studio/app_utils/sections/settings.py b/llm_studio/app_utils/sections/settings.py index 8f65ee5866..6fa797680d 100644 --- a/llm_studio/app_utils/sections/settings.py +++ b/llm_studio/app_utils/sections/settings.py @@ -4,6 +4,7 @@ from h2o_wave import Q, ui from llm_studio.app_utils.sections.common import clean_dashboard +from llm_studio.app_utils.setting_utils import Secrets from llm_studio.src.loggers import Loggers @@ -24,6 +25,34 @@ async def settings(q: Q) -> None: 'Save settings persistently' button below. 
To reload \ the persistently saved settings, use the 'Load settings' button.", ), + ui.separator("Credential Storage"), + ui.inline( + items=[ + ui.label("Credential Handler", width=label_width), + ui.dropdown( + name="credential_saver", + value=q.client["credential_saver"], + choices=[ui.choice(name, name) for name in Secrets.names()], + trigger=False, + width="300px", + ), + ] + ), + ui.message_bar( + type="info", + text="""Method used to save credentials (passwords) \ + for 'Save settings persistently'. \ + The recommended approach for saving credentials (passwords) is to \ + use either Keyring or to avoid permanent storage \ + (requiring re-entry upon app restart). \ + Keyring will be disabled if it is not set up on the host machine. \ + Only resort to local .env if your machine's \ + accessibility is restricted to you.\n\ + When you select 'Save settings persistently', \ + credentials will be removed from all non-selected methods. \ + 'Restore Default Settings' will clear credentials from all methods. 
+                """,
+            ),
             ui.separator("Appearance"),
             ui.inline(
                 items=[
diff --git a/llm_studio/app_utils/setting_utils.py b/llm_studio/app_utils/setting_utils.py
new file mode 100644
index 0000000000..e6d1482b71
--- /dev/null
+++ b/llm_studio/app_utils/setting_utils.py
@@ -0,0 +1,400 @@
+import errno
+import functools
+import logging
+import os
+import pickle
+import signal
+import traceback
+from typing import Any, List
+
+import keyring
+import yaml
+from h2o_wave import Q, ui
+from keyring.errors import KeyringLocked, PasswordDeleteError
+
+from llm_studio.app_utils.config import default_cfg
+from llm_studio.app_utils.utils import get_database_dir, get_user_id
+
+__all__ = [
+    "load_user_settings_and_secrets",
+    "load_default_user_settings",
+    "save_user_settings_and_secrets",
+    "Secrets",
+]
+
+logger = logging.getLogger(__name__)
+# A user setting counts as a secret when its name contains "token" or "key".
+# NOTE(review): sensitive settings with other names are written to the plain
+# YAML settings file - confirm that the heuristic covers all credentials.
+SECRET_KEYS = [
+    key
+    for key in default_cfg.user_settings
+    if any(password in key for password in ["token", "key"])
+]
+USER_SETTING_KEYS = [key for key in default_cfg.user_settings if key not in SECRET_KEYS]
+# name of the handler that never stores anything; also used as the safe fallback
+_NO_SAVER_NAME = "Do not save credentials permanently"
+
+
+async def save_user_settings_and_secrets(q: Q):
+    """
+    Persist the secrets and the remaining user settings held in q.client.
+    """
+    await _save_secrets(q)
+    _save_user_settings(q)
+
+
+def load_user_settings_and_secrets(q: Q):
+    """
+    Load the persisted user settings and secrets into q.client.
+
+    The settings are read first because the "credential_saver" setting
+    decides which handler the secrets are loaded from.
+    """
+    _maybe_migrate_to_yaml(q)
+    _load_user_settings(q)
+    _load_secrets(q)
+
+
+def load_default_user_settings(q: Q):
+    """
+    Reset q.client to the default settings and delete all stored secrets.
+    """
+    for key in default_cfg.user_settings:
+        q.client[key] = default_cfg.user_settings[key]
+    # only the secret keys are ever handed to a credential handler
+    for key in SECRET_KEYS:
+        _clear_secrets(q, key)
+
+
+class NoSaver:
+    """
+    Base class that provides methods for saving, loading, and deleting password entries.
+    The base implementation stores nothing: save and delete do nothing,
+    load always returns None.
+
+    Attributes:
+        username (str): The username associated with the password entries.
+        root_dir (str): The root directory.
+
+    Methods:
+        save(name: str, password: str) -> None:
+            Save a password entry with the given name and password.
+
+        load(name: str) -> str:
+            Load and return the password associated with the given name.
+
+        delete(name: str) -> None:
+            Delete the password entry with the given name.
+
+    """
+
+    def __init__(self, username: str, root_dir: str):
+        self.username = username
+        self.root_dir = root_dir
+
+    def save(self, name: str, password: str):
+        pass
+
+    def load(self, name: str):
+        pass
+
+    def delete(self, name: str):
+        pass
+
+
+class KeyRingSaver(NoSaver):
+    """
+    A class for saving, loading, and deleting passwords using the keyring library.
+    Some machines may not have keyring installed, so this class may not be available.
+    """
+
+    def __init__(self, username: str, root_dir: str):
+        super().__init__(username, root_dir)
+        # service name under which this user's entries are stored in keyring
+        self.namespace = f"{username}_h2o_llmstudio"
+
+    def save(self, name: str, password: str):
+        keyring.set_password(self.namespace, name, password)
+
+    def load(self, name: str):
+        return keyring.get_password(self.namespace, name)
+
+    def delete(self, name: str):
+        try:
+            keyring.delete_password(self.namespace, name)
+        except (KeyringLocked, PasswordDeleteError):
+            # best effort: these two errors are deliberately ignored
+            pass
+        except Exception as e:
+            logger.warning(f"Error deleting password for keyring: {e}")
+
+
+class EnvFileSaver(NoSaver):
+    """
+    This module provides the EnvFileSaver class, which is used to save, load,
+    and delete name-password pairs in an environment file.
+    Only use this class if you are sure that the environment file is secure.
+    """
+
+    @property
+    def filename(self):
+        return os.path.join(self.root_dir, f"{self.username}.env")
+
+    def _read(self) -> dict:
+        """
+        Return the stored name-password pairs, {} if the file is missing or empty.
+        """
+        if not os.path.exists(self.filename):
+            return {}
+        with open(self.filename, "r") as f:
+            # safe_load returns None for an empty file
+            return yaml.safe_load(f) or {}
+
+    def save(self, name: str, password: str):
+        data = self._read()
+        data[name] = password
+        with open(self.filename, "w") as f:
+            yaml.safe_dump(data, f)
+
+    def load(self, name: str):
+        return self._read().get(name, None)
+
+    def delete(self, name: str):
+        data = self._read()
+        if name in data:
+            del data[name]
+            with open(self.filename, "w") as f:
+                yaml.safe_dump(data, f)
+
+
+# https://stackoverflow.com/questions/2281850/timeout-function-if-it-takes-too-long-to-finish
+class TimeoutError(Exception):
+    pass
+
+
+def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
+    """
+    Decorator that raises TimeoutError when the call takes longer than `seconds`.
+    Relies on signal.SIGALRM, the previously installed handler is restored.
+    """
+
+    def decorator(func):
+        def _handle_timeout(signum, frame):
+            raise TimeoutError(error_message)
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
+            signal.alarm(seconds)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                signal.alarm(0)
+                # signal.signal returns None for handlers not installed from Python
+                if previous_handler is not None:
+                    signal.signal(signal.SIGALRM, previous_handler)
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+@timeout(3)
+def check_if_keyring_works():
+    """
+    Test if keyring is working. On misconfigured machines,
+    Keyring may hang up to 2 minutes with the following error:
+     jeepney.wrappers.DBusErrorResponse:
+     [org.freedesktop.DBus.Error.TimedOut]
+     ("Failed to activate service 'org.freedesktop.secrets':
+     timed out (service_start_timeout=120000ms)",)
+
+    To avoid waiting for 2 minutes, we kill the process after 3 seconds.
+    """
+    keyring.get_password("service", "username")
+
+
+class Secrets:
+    """
+    Factory class to get the secrets' handler.
+    """
+
+    _secrets = {
+        _NO_SAVER_NAME: NoSaver,
+        ".env File": EnvFileSaver,
+    }
+    # Keyring is only offered when the probe succeeds while the class is created
+    try:
+        check_if_keyring_works()
+        logger.info("Keyring is correctly configured on this machine.")
+        _secrets["Keyring"] = KeyRingSaver
+    except TimeoutError:
+        logger.warning(
+            "Error loading keyring due to timeout. Disabling keyring save option."
+        )
+    except Exception as e:
+        logger.warning(f"Error loading keyring: {e}. Disabling keyring save option.")
+
+    @classmethod
+    def names(cls) -> List[str]:
+        return sorted(cls._secrets.keys())
+
+    @classmethod
+    def get(cls, name: str) -> Any:
+        return cls._secrets.get(name)
+
+
+def _save_user_settings(q: Q):
+    user_settings = {key: q.client[key] for key in USER_SETTING_KEYS}
+    with open(_get_usersettings_path(q), "w") as f:
+        yaml.dump(user_settings, f)
+
+
+def _load_user_settings(q: Q):
+    if os.path.isfile(_get_usersettings_path(q)):
+        logger.info("Reading user settings")
+        with open(_get_usersettings_path(q), "r") as f:
+            # an empty file is parsed as None
+            user_settings = yaml.load(f, Loader=yaml.FullLoader) or {}
+    else:
+        logger.info("Using default settings")
+        user_settings = {}
+    for key in USER_SETTING_KEYS:
+        q.client[key] = user_settings.get(key, default_cfg.user_settings[key])
+
+
+async def _save_secrets(q: Q):
+    secret_name, secrets_handler = _get_secrets_handler(q)
+    for key in SECRET_KEYS:
+        try:
+            # one-element tuple: tuple(secret_name) would split the name into
+            # characters, so the selected handler would never be excluded
+            _clear_secrets(q, key, excludes=(secret_name,))
+            if q.client[key]:
+                secrets_handler.save(key, q.client[key])
+            else:
+                # an emptied field must not leave a stale secret behind
+                secrets_handler.delete(key)
+
+        except Exception:
+            exception = str(traceback.format_exc())
+            logger.error(f"Could not save password {key} to {secret_name}")
+            q.page["meta"].dialog = ui.dialog(
+                title="Could not save secrets. "
+                "Please choose another Credential Handler.",
+                name="secrets_error",
+                items=[
+                    ui.text(
+                        f"The following error occurred when"
+                        f" using {secret_name}: {exception}."
+                    ),
+                    ui.button(
+                        name="settings/close_error_dialog", label="Close", primary=True
+                    ),
+                ],
+                closable=True,
+            )
+            q.client["keep_meta"] = True
+            await q.page.save()
+            break
+    else:  # if no exception
+        # force dataset connector updated when the user decides to click on save
+        q.client["dataset/import/s3_bucket"] = q.client["default_aws_bucket_name"]
+        q.client["dataset/import/s3_access_key"] = q.client["default_aws_access_key"]
+        q.client["dataset/import/s3_secret_key"] = q.client["default_aws_secret_key"]
+        q.client["dataset/import/kaggle_access_key"] = q.client[
+            "default_kaggle_username"
+        ]
+        q.client["dataset/import/kaggle_secret_key"] = q.client[
+            "default_kaggle_secret_key"
+        ]
+
+
+def _load_secrets(q: Q):
+    secret_name, secrets_handler = _get_secrets_handler(q)
+    for key in SECRET_KEYS:
+        try:
+            # fall back to the default when nothing is stored for this key
+            q.client[key] = secrets_handler.load(key) or default_cfg.user_settings[key]
+        except Exception:
+            logger.error(f"Could not load password {key} from {secret_name}")
+
+
+def _get_secrets_handler(q: Q):
+    secret_name = (
+        q.client["credential_saver"] or default_cfg.user_settings["credential_saver"]
+    )
+    handler_class = Secrets.get(secret_name)
+    if handler_class is None:
+        # e.g. "Keyring" is configured but was disabled on this machine;
+        # fall back to the handler that never writes credentials anywhere
+        logger.warning(
+            f"Credential handler {secret_name} is not available. "
+            f"Using {_NO_SAVER_NAME}."
+        )
+        secret_name = _NO_SAVER_NAME
+        handler_class = Secrets.get(secret_name)
+    secrets_handler = handler_class(
+        username=get_user_id(q), root_dir=get_database_dir(q)
+    )
+    return secret_name, secrets_handler
+
+
+def _clear_secrets(q: Q, name: str, excludes=tuple()):
+    for secret_name in Secrets.names():
+        if secret_name not in excludes:
+            secrets_handler = Secrets.get(secret_name)(
+                username=get_user_id(q), root_dir=get_database_dir(q)
+            )
+
+            secrets_handler.delete(name)
+
+
+def _maybe_migrate_to_yaml(q: Q):
+    """
+    Migrate user settings from a pickle file to a YAML file.
+    """
+    # prior, we used to save the user settings in a pickle file
+    old_usersettings_path = os.path.join(
+        get_database_dir(q), f"{get_user_id(q)}.settings"
+    )
+    if not os.path.isfile(old_usersettings_path):
+        return
+
+    try:
+        # NOTE: unpickling can run arbitrary code; presumably this file was
+        # written by an earlier version of the app itself, never unpickle
+        # files from other sources here.
+        with open(old_usersettings_path, "rb") as f:
+            user_settings = pickle.load(f)
+
+        secret_name, secrets_handler = _get_secrets_handler(q)
+        logger.info(f"Migrating token using {secret_name}")
+        for key in SECRET_KEYS:
+            # same rule as _save_secrets: empty values are not stored
+            if user_settings.get(key):
+                secrets_handler.save(key, user_settings[key])
+
+        with open(_get_usersettings_path(q), "w") as f:
+            yaml.dump(
+                {
+                    key: value
+                    for key, value in user_settings.items()
+                    if key in USER_SETTING_KEYS
+                },
+                f,
+            )
+        os.remove(old_usersettings_path)
+        logger.info(f"Successfully migrated tokens to {secret_name}. Old file deleted.")
+    except Exception as e:
+        logger.warning(
+            f"Could not migrate tokens. "
+            f"Please delete {old_usersettings_path} and set your credentials again. "
+            f"Error: \n\n {e} {traceback.format_exc()}"
+        )
+
+
+def _get_usersettings_path(q: Q):
+    return os.path.join(get_database_dir(q), f"{get_user_id(q)}.yaml")
diff --git a/llm_studio/app_utils/utils.py b/llm_studio/app_utils/utils.py
index fc02235ad6..1aabfd4d77 100644
--- a/llm_studio/app_utils/utils.py
+++ b/llm_studio/app_utils/utils.py
@@ -7,7 +7,6 @@
 import logging
 import math
 import os
-import pickle
 import random
 import re
 import shutil
@@ -1833,42 +1832,6 @@ def check_valid_upload_content(upload_path: str) -> Tuple[bool, str]:
     return valid, error
 
 
-def load_user_settings(q: Q, force_defaults: bool = False):
-    # get settings from settings pickle if it exists or set default values
-    if os.path.isfile(get_usersettings_path(q)) and not force_defaults:
-        logger.info("Reading settings")
-        with open(get_usersettings_path(q), "rb") as f:
-            user_settings = pickle.load(f)
-        for key in default_cfg.user_settings:
-            q.client[key] = user_settings.get(key, default_cfg.user_settings[key])
-    else:
-        logger.info("Using default settings")
- for key in default_cfg.user_settings: - q.client[key] = default_cfg.user_settings[key] - - -def save_user_settings(q: Q): - # Hacky way to get a dict of q.client key/value pairs - user_settings = {} - for key in default_cfg.user_settings: - user_settings.update({key: q.client[key]}) - - # force dataset connector updated when the user decides to click on save - q.client["dataset/import/s3_bucket"] = q.client["default_aws_bucket_name"] - q.client["dataset/import/s3_access_key"] = q.client["default_aws_access_key"] - q.client["dataset/import/s3_secret_key"] = q.client["default_aws_secret_key"] - - q.client["dataset/import/azure_conn_string"] = q.client["default_azure_conn_string"] - q.client["dataset/import/azure_container"] = q.client["default_azure_container"] - - q.client["dataset/import/kaggle_access_key"] = q.client["default_kaggle_username"] - q.client["dataset/import/kaggle_secret_key"] = q.client["default_kaggle_secret_key"] - - with open(get_usersettings_path(q), "wb") as f: - # slightly obfuscate to binary pickle file - pickle.dump(user_settings, f) - - def flatten_dict(d: collections.abc.MutableMapping) -> dict: """ Adapted from https://stackoverflow.com/a/6027615 diff --git a/requirements.txt b/requirements.txt index 26cb4628d1..02e6ca0632 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,8 +15,8 @@ bitsandbytes==0.41.1 bleach==6.0.0 ; python_version >= '3.7' blessed==1.20.0 ; python_version >= '2.7' bokeh==3.2.1 -boto3==1.28.43 -botocore==1.31.43 ; python_version >= '3.7' +boto3==1.28.45 +botocore==1.31.45 ; python_version >= '3.7' bravado==11.0.3 ; python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' and python_full_version != '3.5.0' bravado-core==6.1.0 ; python_version >= '3.7' certifi==2023.7.22 ; python_version >= '3.6' @@ -41,19 +41,22 @@ frozenlist==1.4.0 ; python_version >= '3.8' fsspec[http]==2023.6.0 ; python_version >= '3.8' future==0.18.3 ; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' gitdb==4.0.10 ; 
python_version >= '3.7' -gitpython==3.1.35 ; python_version >= '3.7' +gitpython==3.1.36 ; python_version >= '3.7' gputil==1.4.0 greenlet==2.0.2 ; python_version >= '3' and platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32'))))) h11==0.14.0 ; python_version >= '3.7' h2o-wave==0.26.2 hf-transfer==0.1.3 -httpcore==0.17.3 ; python_version >= '3.7' -httpx==0.24.1 ; python_version >= '3.7' +httpcore==0.18.0 ; python_version >= '3.8' +httpx==0.25.0 ; python_version >= '3.8' huggingface-hub==0.16.4 idna==3.4 ; python_version >= '3.5' +importlib-metadata==6.8.0 ; python_version < '3.12' inquirer==3.1.3 ; python_version >= '3.8' isodate==0.6.1 isoduration==20.11.0 +jaraco.classes==3.3.0 ; python_version >= '3.8' +jeepney==0.8.0 ; sys_platform == 'linux' jinja2==3.1.2 jmespath==1.0.1 ; python_version >= '3.7' joblib==1.3.2 ; python_version >= '3.7' @@ -62,9 +65,11 @@ jsonref==1.1.0 ; python_version >= '3.7' jsonschema==4.19.0 ; python_version >= '3.8' jsonschema-specifications==2023.7.1 ; python_version >= '3.8' kaggle==1.5.16 +keyring==24.2.0 lit==16.0.6 markupsafe==2.1.3 ; python_version >= '3.7' monotonic==1.6 +more-itertools==10.1.0 ; python_version >= '3.8' mpmath==1.3.0 msgpack==1.0.5 multidict==6.0.4 ; python_version >= '3.7' @@ -103,8 +108,9 @@ sacrebleu==2.0.0 safetensors==0.3.3 scikit-learn==1.3.0 scipy==1.11.2 ; python_version < '3.13' and python_version >= '3.9' +secretstorage==3.3.3 ; sys_platform == 'linux' sentencepiece==0.1.99 -setuptools==68.2.0 ; python_version >= '3.8' +setuptools==68.2.1 ; python_version >= '3.8' simplejson==3.19.1 ; python_version >= '2.5' and python_version not in '3.0, 3.1, 3.2, 3.3' six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' smmap==5.0.0 ; python_version >= '3.6' @@ -135,8 +141,9 @@ uvicorn==0.23.2 ; python_version >= '3.8' 
wcwidth==0.2.6 webcolors==1.13 webencodings==0.5.1 -websocket-client==1.6.2 ; python_version >= '3.8' +websocket-client==1.6.3 ; python_version >= '3.8' xxhash==3.3.0 ; python_version >= '3.7' xyzservices==2023.7.0 ; python_version >= '3.8' yarg==0.1.9 yarl==1.9.2 ; python_version >= '3.7' +zipp==3.16.2 ; python_version >= '3.8' diff --git a/tests/app_utils/utils/setting_utils.py b/tests/app_utils/utils/setting_utils.py new file mode 100644 index 0000000000..391f6e258d --- /dev/null +++ b/tests/app_utils/utils/setting_utils.py @@ -0,0 +1,52 @@ +from unittest import mock + +from app_utils.config import default_cfg +from app_utils.utils.setting_utils import ( + EnvFileSaver, + KeyRingSaver, + NoSaver, + Secrets, + load_default_user_settings, +) + + +def test_no_saver(): + saver = NoSaver("test_user", "/") + assert saver.save("name", "password") is None + assert saver.load("name") is None + assert saver.delete("name") is None + + +def test_keyring_saver(mocker): + mocker.patch("keyring.set_password") + mocker.patch("keyring.get_password", return_value="password") + mocker.patch("keyring.delete_password") + saver = KeyRingSaver("test_user", "/") + saver.save("name", "password") + assert saver.load("name") == "password" + saver.delete("name") + assert mocker.patch("keyring.delete_password").is_called + + +def test_env_file_saver(tmpdir): + saver = EnvFileSaver("test_user", str(tmpdir)) + saver.save("name", "password") + saver.save("name2", "password2") + assert saver.load("name") == "password" + saver.delete("name") + assert saver.load("name") is None + assert saver.load("name2") == "password2" + + +def test_secrets_get(): + assert isinstance(Secrets.get("Do not save credentials permanently"), type) + assert isinstance(Secrets.get("Keyring"), type) + assert isinstance(Secrets.get(".env File"), type) + + +def test_load_default_user_settings(mocker): + q = mock.MagicMock() + q.client = dict() + mocker.patch("app_utils.utils.setting_utils.clear_secrets", return_value=None) 
+ load_default_user_settings(q) + assert set(q.client.keys()) == set(default_cfg.user_settings.keys())