Dev gan trainer #24
base: main
Conversation
…ce for complex orchestration that involves combining the context outputs from multiple forward groups.
…AN training process of updating the generator and discriminator separately.
…mprove readability
I double-checked and all the tests pass, nice job. Given the great testing work here, you might consider running the tests as an automated GitHub Actions job. The benefit is that it would give you and reviewers confidence about the changes within the context of the pull request (similar to how the pre-commit checks operate).
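As a rough sketch of what such a workflow could look like (the file path, Python version, and install/test commands are assumptions to adjust to the project's setup):

```yaml
# .github/workflows/tests.yml (hypothetical path and settings)
name: tests
on: [pull_request]
jobs:
  pytest:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      # assumes a "test" extra in the package metadata; adjust as needed
      - run: pip install -e ".[test]"
      - run: pytest
```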
… inputs, targets, and preds.
… optimizer management
… forward groups derived context objects
…ndling and streamline output dimension calculations
…to use lowercase tuple syntax
…lemented to raising NotImplementedError
…od to raise KeyError
…to support a variety of forward function signatures in concrete loss classes, the old abstract AbstractLoss class has been replaced with a protocol and a non-abstract BaseLoss. Related modules and methods using the old abstract losses are also updated, and a new module is implemented to house brushed-up wGAN training losses, while the obsolete modules were removed.
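A minimal sketch of the protocol-plus-non-abstract-base pattern described above (the class names and signatures here are illustrative assumptions, not the module's actual API; plain floats stand in for tensors):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Loss(Protocol):
    """Structural interface: anything callable like a loss conforms,
    with no inheritance required."""
    def __call__(self, *args, **kwargs) -> float: ...


class BaseLoss:
    """Non-abstract base: concrete losses override forward() with
    whatever signature they need (inputs, targets, preds, ...)."""
    def __call__(self, *args, **kwargs) -> float:
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs) -> float:
        raise NotImplementedError


class CriticScoreLoss(BaseLoss):
    """WGAN-style generator loss: maximize the critic's mean score on
    fakes, i.e. minimize its negation."""
    def forward(self, fake_scores: list[float]) -> float:
        return -sum(fake_scores) / len(fake_scores)
```

Because `Loss` is a runtime-checkable protocol, any concrete loss satisfies `isinstance(loss, Loss)` purely by shape, which is what lets different losses take different forward signatures.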
… for inputs and outputs, enhancing type clarity and consistency.
…l instead of BaseGeneratorModel for consistency.
… for flexibility.
…flowLogger for optional dictionary parameters
…weights and add model registry for logging parameters
This PR adds back the Generative Adversarial Network (GAN) training functionality that was removed in refactor #19, where the model forward-pass logic was decoupled from the trainer and promoted to a separate engine subpackage (#19 only included the engine parts for regular UNet training).
Introduces
- src/virtual_stain_flow/engine/forward_groups.py: Added DiscriminatorForwardGroup, a standardized interface for discriminator forward passes.
- src/virtual_stain_flow/engine/orchestrators.py: An additional layer of abstraction, OrchestratedStep and GANOrchestrator, between the forward-passing engine and the trainer classes, to define complex orchestrations involving more than one model (neural network), as needed when training GANs. Specifically, GANOrchestrator provides a training interface that trains the generator and discriminator separately, so the orchestrator abstraction keeps things cleaner.
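To illustrate the alternating-update pattern such an orchestrator encapsulates, here is a hedged sketch (the class and method names are assumptions for illustration, not the PR's actual API):

```python
import torch


class GANOrchestrator:
    """Sketch: owns both models and their optimizers, and exposes
    separate update steps for the discriminator and the generator."""

    def __init__(self, generator, discriminator, g_opt, d_opt):
        self.generator = generator
        self.discriminator = discriminator
        self.g_opt = g_opt
        self.d_opt = d_opt

    def discriminator_step(self, real, noise):
        # Update only the discriminator; detach fakes so no gradient
        # flows back into the generator.
        self.d_opt.zero_grad()
        fake = self.generator(noise).detach()
        # WGAN-style critic loss: score fakes low, reals high.
        loss = self.discriminator(fake).mean() - self.discriminator(real).mean()
        loss.backward()
        self.d_opt.step()
        return loss.item()

    def generator_step(self, noise):
        # Update only the generator: push the critic's score on fakes up.
        # (A fuller implementation would also freeze/zero the critic's
        # gradients here; zero_grad in discriminator_step clears them.)
        self.g_opt.zero_grad()
        loss = -self.discriminator(self.generator(noise)).mean()
        loss.backward()
        self.g_opt.step()
        return loss.item()
```

Keeping the two steps on one object is what lets a trainer loop stay model-agnostic: it calls `discriminator_step` and `generator_step` without knowing how many networks are involved.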
New operations for Context class
Minimal tests