wwliu
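05/16/2022, 9:04 PM

from typing import Any, Dict
import mlflow.sklearn
from kedro.framework.hooks import hook_impl
from kedro.pipeline.node import Node

class ModelTrackingHooks:
    @hook_impl
    def after_node_run(self, node: Node, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> None:
        # Only act on the node that wraps the train_model function
        if node._func_name == "train_model":
            model = outputs["example_model"]
            mlflow.sklearn.log_model(model, "model")
            mlflow.log_params(inputs["parameters"])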
My question is: I only need to log metrics in this one train_model node, but as I understand it, this hook runs every time any node finishes, and the whole pipeline could contain many nodes. Is there a way to specify which node this hook is attached to?
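Assuming, as the post describes, that `after_node_run` is invoked after every node, one common pattern is to keep a single hook and return early for nodes you don't care about. The sketch below is plain Python with no Kedro or MLflow dependency: `FakeNode`, `SelectiveTrackingHook`, and `run_pipeline` are hypothetical stand-ins that simulate a runner calling the hook after each node. It selects nodes by name or by tag; in Kedro, a `Node` also exposes `name` and `tags`, but treat the mapping as an assumption to check against your Kedro version.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, FrozenSet, List


@dataclass(frozen=True)
class FakeNode:
    """Hypothetical stand-in for a pipeline node: only a name and tags."""
    name: str
    tags: FrozenSet[str] = field(default_factory=frozenset)


class SelectiveTrackingHook:
    """Hypothetical hook that runs `track` only for selected nodes."""

    def __init__(
        self,
        node_names: FrozenSet[str],
        tags: FrozenSet[str],
        track: Callable[[FakeNode, Dict[str, Any]], None],
    ) -> None:
        self.node_names = node_names
        self.tags = tags
        self.track = track

    def after_node_run(self, node: FakeNode, outputs: Dict[str, Any]) -> None:
        # Called after every node; return early unless the node matches
        # by name or shares at least one tag with the filter.
        if node.name not in self.node_names and not (node.tags & self.tags):
            return
        self.track(node, outputs)


def run_pipeline(nodes: List[FakeNode], hook: SelectiveTrackingHook) -> None:
    # Hypothetical runner: invokes the hook after each node "finishes".
    for n in nodes:
        hook.after_node_run(n, outputs={"result": n.name})


tracked: List[str] = []
hook = SelectiveTrackingHook(
    node_names=frozenset({"train_model"}),
    tags=frozenset({"track"}),
    track=lambda node, outputs: tracked.append(node.name),
)
run_pipeline(
    [
        FakeNode("preprocess"),
        FakeNode("train_model"),
        FakeNode("evaluate", frozenset({"track"})),
    ],
    hook,
)
print(tracked)  # ['train_model', 'evaluate']
```

The hook still fires for every node, so the per-node cost is one set lookup. Filtering on a tag means you opt a node into tracking where the node is defined, rather than relying on a private attribute such as `_func_name`.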