Source code for domainlab.algos.observers.c_obvisitor_gen
from domainlab.algos.observers.b_obvisitor import ObVisitor
from domainlab.utils.flows_gen_img_model import fun_gen
from domainlab.utils.logger import Logger
[docs]
class ObVisitorGen(ObVisitor):
    """
    Observer-visitor for generative models.

    Once the parent's end-of-run handling is done, it calls ``fun_gen`` for
    several variants of the host trainer's model.
    """

    def after_all(self):
        """
        Run the parent ``after_all`` hook, then call ``fun_gen`` three times.

        The calls are made, in order, for:

        1. the trainer's current model (logged as the final model at the
           last epoch), passed as-is;
        2. the model returned by ``load("oracle")``;
        3. the model returned by ``load()`` with no argument (logged as the
           selected model).

        The two reloaded models are moved to ``self.device`` and put into
        eval mode before being handed to ``fun_gen``.
        """
        super().after_all()
        log = Logger.get_logger()
        trainer = self.host_trainer

        def render(tag, net):
            # The sub-folder name is the visitor's model name followed by the tag.
            fun_gen(
                subfolder_na=trainer.model.visitor.model_name + tag,
                args=trainer.aconf,
                node=trainer.task,
                model=net,
                device=self.device,
            )

        log.info("generating images for final model at last epoch")
        render("final", trainer.model)

        # Each entry: (log message, sub-folder tag, positional args for load()).
        # @FIXME: name "oracle" is a strong dependency
        reloads = (
            ("generating images for oracle model", "oracle", ("oracle",)),
            ("generating images for selected model", "selected", ()),
        )
        for message, tag, load_args in reloads:
            log.info(message)
            net = trainer.model.load(*load_args).to(self.device)
            net.eval()
            render(tag, net)