
Conversation

@cdoern cdoern commented May 27, 2025

Introduce a new design for key components of main_ds.py, splitting Model initialization, Accelerator initialization, Optimizer initialization, and Checkpoint saving into classes. This commit introduces the Model class.

NOTE: a follow-up to this work will introduce classes/structure for the DataLoader, Sampler, etc. These were left out of this PR given the already large scope of change.

The Model class wraps the various AutoModel classes we support and aims to be a lightweight wrapper that makes the library easier to use with different model types. setup_optimizer resides within the Model class and returns one of the optimizer types we support.

These classes are one of a few steps needed to "SDK-ify" the training library.

Adding structure to code via classes can be either someone's favorite or least favorite thing, so I figured I'd explain myself before continuing. Here is my rationale:

Classes provide logical structure to code, especially code meant to be a publicly consumable SDK, and allow you to associate related objects and methods with one another.

Grouping functionality under the Model, Accelerator, and Checkpointer classes inherently reduces code complexity and duplication. Storing attributes like self.distributed_framework and self.lora_config so that they are accessible from every method of the class drastically reduces the number of arguments each method takes, as well as the complexity of return values. Simpler methods, arguments, and return values make the code easier to test.

class ModelTypes(Enum):
    LIGER = "Liger"
    CAUSALLM = "Causallm"
    DOLOMITE = "Dolomite"
@RobotSail (Member) commented:

We've dropped dolomite, no need to include this.

Contributor:

@RobotSail Interesting! What does it mean exactly? If I grep through the code, I still see hits for dolomite, including the mandatory dependency on instructlab-dolomite. Was some decision made to drop it? Should we clean these remnants from the tree then?

Collaborator:

Being worked on in #589

mergify bot commented May 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. @cdoern please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@booxter booxter left a comment

I haven't reviewed tests or Accelerator class in detail. I need to step off this PR. Posting questions and concerns I have collected so far.

parser.add_argument(
    "--model-class",
    type=str,
    default=ModelTypes.CAUSALLM.value,
Contributor:

nit: you can use choices=[x.value for x in ModelTypes] to avoid listing them below
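The suggestion above can be sketched as follows; the ModelTypes enum mirrors the one in the diff, and the parse_args call at the end exists only to demonstrate the validation:

```python
import argparse
from enum import Enum

class ModelTypes(Enum):
    LIGER = "Liger"
    CAUSALLM = "Causallm"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-class",
    type=str,
    default=ModelTypes.CAUSALLM.value,
    # derive the valid values from the enum instead of listing them by hand
    choices=[m.value for m in ModelTypes],
)

args = parser.parse_args(["--model-class", "Liger"])
print(args.model_class)  # Liger
```

With this shape, adding a new model type to the enum automatically makes it a valid CLI value.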

    sharding_strategy: ShardingStrategies = ShardingStrategies.HYBRID_SHARD


class Optimizers(Enum):
Contributor:

(No action required, observation) I think it's more common to name enums in the singular, not the plural. But it's a matter of habit, of course.

Contributor Author:

changed to singular

    from deepspeed.ops.adam import DeepSpeedCPUAdam
except ImportError:
    DeepSpeedCPUAdam = None
local_rank = int(os.getenv("LOCAL_RANK", "0"))
Contributor:

(No action required) I know it was done in main_ds, so you are not introducing anything new here, but consider not running code or issuing warnings when importing the module. An import should not, generally, produce side effects of this sort, especially in a library. Consider warning later, when the missing class is actually referenced or used.
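One hedged sketch of the "warn at use time" pattern being suggested; the get_cpu_adam helper name is illustrative, not code from this PR:

```python
# Import-time: swallow the ImportError quietly and record the absence.
try:
    from deepspeed.ops.adam import DeepSpeedCPUAdam
except ImportError:
    DeepSpeedCPUAdam = None

def get_cpu_adam():
    """Fail only when the optimizer is actually requested, not at import."""
    if DeepSpeedCPUAdam is None:
        raise RuntimeError(
            "DeepSpeedCPUAdam requires the optional 'deepspeed' package; "
            "install it to use the CPU Adam optimizer."
        )
    return DeepSpeedCPUAdam
```

Importing the module then stays silent; users who never select the DeepSpeed optimizer never see the warning or error.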

    output_dir: str,
    distributed_framework: DistributedBackend,
    model_type: ModelTypes,
    noise_alpha: Optional[float],
Contributor:

nit: use type | None instead of Optional[type]

)
self.model.config.eos_token_id = self.tokenizer.eos_token_id

if "ForCausalLM" not in self.model.__class__.__name__:
Contributor:

this is fragile; can you think of a more robust way of checking it? If not, maybe the Model class could have a helper method to hide the check?

Contributor Author:

this is inherited from main:

if "ForCausalLM" not in model.__class__.__name__:

I will refactor into a helper and we can investigate a better solution if there is one
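A minimal sketch of that helper; the Model shape here is illustrative, and the stand-in classes only mimic transformers class names:

```python
class Model:
    """Illustrative subset of the PR's Model class."""

    def __init__(self, model):
        self.model = model

    def is_causal_lm(self) -> bool:
        # Same class-name heuristic as before, but hidden behind one method
        # so a more robust check can later be swapped in at a single site.
        return "ForCausalLM" in type(self.model).__name__

class LlamaForCausalLM:  # stand-in for a real transformers class
    pass

class T5ForConditionalGeneration:  # stand-in; not a causal LM
    pass

print(Model(LlamaForCausalLM()).is_causal_lm())            # True
print(Model(T5ForConditionalGeneration()).is_causal_lm())  # False
```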

from .utils import add_noisy_embeddings, convert_loss_to_reduce_sum

self.model = convert_loss_to_reduce_sum(
    self.model, use_dolomite=(self.model_type == "dolomite")
Contributor:

incorrect enum == str check: an Enum member never compares equal to its value string

Contributor Author:

fixed with the child classes I created, I think
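The bug this exchange points at is easy to reproduce in isolation; an Enum member never compares equal to its value string:

```python
from enum import Enum

class ModelTypes(Enum):
    DOLOMITE = "Dolomite"

model_type = ModelTypes.DOLOMITE

print(model_type == "dolomite")           # False: Enum vs str is never equal
print(model_type == "Dolomite")           # still False, even with matching case
print(model_type == ModelTypes.DOLOMITE)  # True: compare against the member
print(model_type.value == "Dolomite")     # True: or compare .value explicitly
```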

"""Check if a GPU supports FlashAttention."""
major, minor = torch.cuda.get_device_capability(device_id)
# Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
is_sm8x = major == 8 and minor >= 0
Contributor:

(No action required) Could be:

if ...:
    return True
if ...:
    return True
if ...:
    return True
return False
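Applied to the capability check quoted above, the early-return shape might look like this; the function name and the plain-int signature are illustrative, and the SM thresholds are the ones named in the diff's comment (Ampere SM 8.x, SM 9.0):

```python
def supports_flash_attention(major: int, minor: int) -> bool:
    """True when the device capability is Ampere (SM 8.x) or SM 9.0."""
    if major == 8:                 # Ampere, any minor revision
        return True
    if major == 9 and minor == 0:  # SM 9.0 (Hopper)
        return True
    return False

print(supports_flash_attention(8, 6))  # True
print(supports_flash_attention(7, 5))  # False
```

In the real code the (major, minor) pair would come from torch.cuda.get_device_capability(device_id), as in the excerpt.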

cdoern commented May 30, 2025

@booxter thanks for the review. I actually meant to remove Accelerator in this PR, which is why the class confusingly goes unused. I intend to introduce it in a 2/n PR, just for clarity.

As for most of the other comments, a lot of them are inherited from the existing code or are missteps from splitting out my mega PR (I forgot to take my changes from utils.py, for example). Will take another pass here. Thanks!

mergify bot commented Jun 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. @cdoern please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

github-actions bot commented Jun 3, 2025

E2E (NVIDIA L40S x4) workflow launched on this PR: View run

github-actions bot commented Jun 3, 2025

e2e workflow failed on this PR: View run, please investigate.

github-actions bot commented Jun 3, 2025

e2e workflow succeeded on this PR: View run, congrats!

@booxter booxter left a comment

The bnb (bitsandbytes) question should be addressed before merging. Do we need it? Is it OK to drop it here?

base_model_args = {
    "pretrained_model_name_or_path": args.model_name_or_path,
    "torch_dtype": torch.bfloat16,
    "quantization_config": bnb_config,
Contributor:

Do you have an answer to this? Should the drop be included here?


self.reconcile_tokenizer()
if self.lora_config:
    # First Party
Contributor:

bump

@cdoern cdoern requested a review from booxter June 4, 2025 12:51
github-actions bot commented Jun 4, 2025

E2E (NVIDIA L40S x4) workflow launched on this PR: View run

cdoern commented Jun 4, 2025

I changed model.parameters to a property, so I need to remove the .parameters() refs. Tests should pass now.

@booxter booxter left a comment

This looks reasonable. It's hard to review a large patch line by line through multiple iterations, so this follow-up review focused on the high-level question of whether my prior feedback was addressed. I think it was (bnb restored; logging module used; duplicate functions cleaned up; Accelerator class removed; etc.)

github-actions bot commented Jun 4, 2025

E2E (NVIDIA L40S x4) workflow launched on this PR: View run

cdoern commented Jun 4, 2025

model.parameters cannot be a property because accelerate expects it to be a method: https://github.com/instructlab/training/actions/runs/15443149090/job/43465937222
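A toy reproduction of why the property breaks Accelerate, which (like most torch code) invokes model.parameters() as a call; both classes below are illustrative stand-ins, not code from this PR:

```python
class PropModel:
    @property
    def parameters(self):           # property returns the iterator directly
        return iter([1, 2, 3])

class MethodModel:
    def parameters(self):           # torch.nn.Module-style method
        return iter([1, 2, 3])

try:
    PropModel().parameters()        # callers write model.parameters()
except TypeError as exc:
    # the property already returned an iterator, which is not callable
    print("property breaks callers:", exc)

print(list(MethodModel().parameters()))  # [1, 2, 3]
```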

github-actions bot commented Jun 4, 2025

E2E (NVIDIA L40S x4) workflow launched on this PR: View run

    lora_config: Optional[LoraConfig] = None,
    lora_quant_bits: int = 0,
):
    self.lora_config = lora_config
Collaborator:

I think lora_config should not be put inside the Model class; LoRA should act as a wrapper around our model. We can deliberate this in a future issue/PR.

@cdoern cdoern added the hold label Jun 4, 2025
cdoern commented Jun 4, 2025

Holding for the L40S test to pass.

@fynnsu fynnsu left a comment

I support moving quickly with these PRs so that we can start to refine the final shape of the new SDK-style codebase.

This is reasonable for now, pending future PRs to update the other components.

github-actions bot commented Jun 4, 2025

e2e workflow succeeded on this PR: View run, congrats!

@cdoern cdoern removed the hold label Jun 4, 2025
@mergify mergify bot merged commit e78908c into instructlab:main Jun 4, 2025
18 checks passed
@JamesKunstle JamesKunstle mentioned this pull request Jun 4, 2025
Labels: testing (Relates to testing)
5 participants