Detect Stalled Training and Stop Training Job Using SageMaker Debugger Rule
This notebook shows you how to use the StalledTrainingRule built-in rule. When the rule detects that your training job has been inactive for a specified period of time, it can stop the training job. This helps you monitor the training job status and avoid wasting resources on stalled jobs.
How the StalledTrainingRule Built-in Rule Works
Amazon SageMaker Debugger captures the tensors you want to watch from training jobs running on AWS Deep Learning Containers or on your local machine. If you use one of the Debugger-integrated Deep Learning Containers, you don’t need to change your training script to use the built-in rules. For information about the SageMaker frameworks and versions that Debugger supports, see Debugger-supported framework versions for zero script change.
If you want to run a training script that uses a framework Debugger only partially supports, or your own custom container, you need to manually register the Debugger hook in your training script. The smdebug library provides tools to help with hook registration, and the sample script in the src folder includes the hook registration code as commented-out lines. For more information about manually registering the Debugger hook in this case, see the training script at ./src/simple_stalled_training.py and the documentation for the smdebug TensorFlow hook, smdebug PyTorch hook, smdebug MXNet hook, and smdebug XGBoost hook.
The Debugger StalledTrainingRule watches tensor updates from your training job. If the rule finds no new tensors written to the default S3 URI within the threshold period, it takes action by calling the StopTrainingJob API operation. The following code cells set up a SageMaker TensorFlow estimator with the Debugger StalledTrainingRule to watch the pre-built losses tensor collection.
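The check the rule performs can be pictured with a simplified model. This is a sketch, not Debugger's implementation: it assumes the rule knows only the timestamp of the most recent tensor write and compares the elapsed time against the threshold in seconds. The function name is_stalled is hypothetical.

```python
def is_stalled(last_tensor_write, now, threshold_seconds):
    """Hypothetical model of the stall check (not the Debugger implementation).

    last_tensor_write and now are timestamps in seconds, for example from time.time().
    Returns True when no tensor has been written for at least threshold_seconds.
    """
    if threshold_seconds <= 0:
        raise ValueError("threshold_seconds must be positive")
    return now - last_tensor_write >= threshold_seconds


# A job that last wrote tensors 30 seconds ago is still considered active...
print(is_stalled(last_tensor_write=1000.0, now=1030.0, threshold_seconds=120))  # False
# ...but one that has been silent for 5 minutes is treated as stalled.
print(is_stalled(last_tensor_write=1000.0, now=1300.0, threshold_seconds=120))  # True
```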
Import SageMaker Python SDK
import sagemaker
from sagemaker.tensorflow import TensorFlow
print(sagemaker.__version__)
Import SageMaker Debugger classes for rule configuration
from sagemaker.debugger import Rule, CollectionConfig, rule_configs
Create a unique training job prefix
You must specify a unique training job name prefix so that StalledTrainingRule can identify the exact training job to monitor and stop when it detects a stalled training job. If multiple training jobs share the same prefix, the rule might act on other training jobs. If the rule cannot find the exact training job name with the provided prefix, it falls back to safe mode and does not stop the training job. The rule evaluation runs in parallel with the training job. To access the rule job logs, see the section Get a direct Amazon CloudWatch URL to find the current rule processing job log later in this notebook.
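Why the prefix must be unique can be illustrated with a hypothetical matcher that treats a training job as a candidate when its name starts with the prefix. The job names below are made up, and the rule's real matching logic may differ; the sketch only shows that a timestamped prefix narrows the candidates to a single job.

```python
import time


def matching_jobs(job_names, prefix):
    """Return the job names that start with prefix (hypothetical matching model)."""
    return [name for name in job_names if name.startswith(prefix)]


# Made-up job names: two earlier runs of this demo and one unrelated job.
existing_jobs = [
    "smdebug-stalled-demo-1600000000-2020-09-13-12-26-40-123",
    "smdebug-stalled-demo-1600000500-2020-09-13-12-35-00-456",
    "xgboost-churn-2020-09-13-12-40-00-789",
]

# A shared, non-unique prefix matches both demo jobs, so the target is ambiguous.
print(matching_jobs(existing_jobs, "smdebug-stalled-demo-"))

# Appending the current Unix time, as the notebook does, makes the prefix unique.
unique_prefix = "smdebug-stalled-demo-" + str(int(time.time()))
new_job = unique_prefix + "-000"  # made-up suffix standing in for SageMaker's timestamp
print(matching_jobs(existing_jobs + [new_job], unique_prefix))  # only new_job
```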
The following code cell includes:
* a line of code that creates a unique base_job_name_prefix
* a StalledTrainingRule configuration object
* a SageMaker TensorFlow estimator configured with the Debugger rules parameter to run the built-in rule
Note: By default, Debugger collects loss tensors every 500 steps.
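Because the rule reacts to missing tensor updates, the save interval and the rule threshold interact. Assuming, as a simplification, that the rule only sees activity when a new loss tensor is saved, a job with slow enough steps could go longer than the threshold between saves while still training normally. The helper names and step times below are hypothetical.

```python
DEFAULT_SAVE_INTERVAL_STEPS = 500  # Debugger's default save interval for loss tensors


def seconds_between_saves(seconds_per_step, save_interval_steps=DEFAULT_SAVE_INTERVAL_STEPS):
    """Wall-clock seconds between two consecutive tensor saves."""
    return seconds_per_step * save_interval_steps


def threshold_is_safe(threshold_seconds, seconds_per_step,
                      save_interval_steps=DEFAULT_SAVE_INTERVAL_STEPS):
    """True if a healthy job saves tensors more often than the rule threshold."""
    return seconds_between_saves(seconds_per_step, save_interval_steps) < threshold_seconds


# 0.1 s per step: a save every ~50 s, well inside a 120 s threshold.
print(threshold_is_safe(threshold_seconds=120, seconds_per_step=0.1))  # True
# 0.3 s per step: a save every ~150 s, so a healthy job could look stalled.
print(threshold_is_safe(threshold_seconds=120, seconds_per_step=0.3))  # False
```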
# Append the current time to the job name prefix to make base_job_name_prefix unique
import time
base_job_name_prefix = 'smdebug-stalled-demo-' + str(int(time.time()))

# Configure the StalledTrainingRule with its rule parameters
stalled_training_job_rule = [
    Rule.sagemaker(
        base_config=rule_configs.stalled_training_rule(),
        rule_parameters={
            "threshold": "120",               # seconds without new tensors before the rule fires
            "stop_training_on_fire": "True",  # stop the training job when the rule fires
            "training_job_name_prefix": base_job_name_prefix
        }
    )
]

# Configure a SageMaker TensorFlow estimator
# Note: train_instance_count, train_instance_type, and train_max_run are
# SageMaker Python SDK v1 parameter names; SDK v2 renames them to
# instance_count, instance_type, and max_run.
estimator = TensorFlow(
    role=sagemaker.get_execution_role(),
    base_job_name=base_job_name_prefix,
    train_instance_count=1,
    train_instance_type='ml.m5.4xlarge',
    entry_point='src/simple_stalled_training.py',  # This sample script forces the training job to sleep for 10 minutes
    framework_version='1.15.0',
    py_version='py3',
    train_max_run=3600,
    # Debugger-specific parameter
    rules=stalled_training_job_rule
)
estimator.fit(wait=False)
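The values in the rule_parameters dictionary of the rule configuration above are strings, because the SageMaker API passes rule parameters as a string-to-string map. If you prefer to work with typed values, a small helper can do the conversion. stalled_rule_parameters is a hypothetical name; its result would be passed as rule_parameters to Rule.sagemaker.

```python
def stalled_rule_parameters(threshold_seconds, stop_training_on_fire, training_job_name_prefix):
    """Build the string-valued rule_parameters dict for StalledTrainingRule.

    Hypothetical convenience helper; it only validates and converts typed values to strings.
    """
    if threshold_seconds <= 0:
        raise ValueError("threshold_seconds must be a positive number of seconds")
    return {
        "threshold": str(int(threshold_seconds)),
        # str(True) is "True", matching the value used in this notebook.
        "stop_training_on_fire": str(bool(stop_training_on_fire)),
        "training_job_name_prefix": training_job_name_prefix,
    }


params = stalled_rule_parameters(120, True, "smdebug-stalled-demo-1600000000")
print(params)
```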
Monitoring Training and Rule Evaluation Status
After you call estimator.fit(), SageMaker starts a training job in the background, and Debugger starts a StalledTrainingRule rule evaluation job in parallel. Because the training script ends with a few lines of code that force it to sleep for 10 minutes, the RuleEvaluationStatus of StalledTrainingRule changes to IssuesFound about 2 minutes (the 120-second threshold) after the sleep starts, and the rule then calls the StopTrainingJob API.
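The timing described above can be checked with simple arithmetic. This sketch assumes the sleep starts right after the last tensor save and ignores the rule's own evaluation latency and the time StopTrainingJob takes to take effect, so real timings will be somewhat longer. The function name stall_timeline is hypothetical.

```python
def stall_timeline(sleep_seconds, threshold_seconds):
    """Return (seconds into the sleep when the rule fires, idle seconds avoided).

    Returns None if the sleep ends before the threshold is reached.
    Simplified model: ignores rule evaluation latency and job shutdown time.
    """
    if sleep_seconds < threshold_seconds:
        return None
    return threshold_seconds, sleep_seconds - threshold_seconds


# 10-minute sleep in the sample script, 120-second threshold in the rule configuration.
fire_after, avoided = stall_timeline(sleep_seconds=10 * 60, threshold_seconds=120)
print(fire_after, avoided)  # 120 480
```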
Print the training job name
The following cell prints the name of the training job running in the background and retrieves its current description, which includes the training status.
job_name = estimator.latest_training_job.name
print('Training job name: {}'.format(job_name))
client = estimator.sagemaker_session.sagemaker_client
description = client.describe_training_job(TrainingJobName=job_name)
Output the current job status and the rule evaluation status
The following cell tracks the status of the training job until the SecondaryStatus changes to Stopped or Completed. During training, Debugger collects output tensors from the training job and monitors the job with the configured rules.
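import time

# Poll every 15 seconds until the job reaches a terminal SecondaryStatus.
# Failed and MaxRuntimeExceeded are included so the loop cannot run forever
# if the job fails or hits train_max_run instead of being stopped by the rule.
terminal_statuses = {'Stopped', 'Completed', 'Failed', 'MaxRuntimeExceeded'}
if description['TrainingJobStatus'] != 'Completed':
    while description['SecondaryStatus'] not in terminal_statuses:
        description = client.describe_training_job(TrainingJobName=job_name)
        primary_status = description['TrainingJobStatus']
        secondary_status = description['SecondaryStatus']
        # Summary of the first (here, the only) Debugger rule attached to the job
        rule_summary = estimator.latest_training_job.rule_job_summary()[0]
        print('Current job status: [PrimaryStatus: {}, SecondaryStatus: {}] | {} Rule Evaluation Status: {}'
              .format(primary_status, secondary_status,
                      rule_summary["RuleConfigurationName"],
                      rule_summary["RuleEvaluationStatus"]))
        time.sleep(15)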
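The polling loop above can also be written as a reusable function. In this sketch, wait_for_training_job is a hypothetical helper that accepts any zero-argument callable returning a DescribeTrainingJob-style dictionary, so in the notebook you could pass lambda: client.describe_training_job(TrainingJobName=job_name), or a stub when testing without AWS. The set of terminal secondary statuses is an assumption covering the values named in this notebook plus common failure states.

```python
import time

# Secondary statuses after which the job makes no further progress (assumed set).
TERMINAL_SECONDARY_STATUSES = {
    "Completed", "Stopped", "Failed", "MaxRuntimeExceeded", "MaxWaitTimeExceeded",
}


def wait_for_training_job(describe, poll_seconds=15, max_polls=None, on_poll=print):
    """Poll describe() until SecondaryStatus is terminal; return the last description.

    describe: zero-argument callable returning a dict shaped like the
    DescribeTrainingJob response (only the two status keys are used here).
    max_polls: optional safety limit; raises TimeoutError when exceeded.
    """
    polls = 0
    while True:
        description = describe()
        on_poll("[PrimaryStatus: {}, SecondaryStatus: {}]".format(
            description["TrainingJobStatus"], description["SecondaryStatus"]))
        if description["SecondaryStatus"] in TERMINAL_SECONDARY_STATUSES:
            return description
        polls += 1
        if max_polls is not None and polls >= max_polls:
            raise TimeoutError("training job did not finish within max_polls")
        time.sleep(poll_seconds)


# Stubbed responses that mimic a job stopped by the rule (no AWS calls are made).
responses = iter([
    {"TrainingJobStatus": "InProgress", "SecondaryStatus": "Training"},
    {"TrainingJobStatus": "Stopping", "SecondaryStatus": "Stopping"},
    {"TrainingJobStatus": "Stopped", "SecondaryStatus": "Stopped"},
])
final = wait_for_training_job(lambda: next(responses), poll_seconds=0)
print(final["SecondaryStatus"])  # Stopped
```

Passing a callable instead of a client keeps the helper independent of boto3, which is what makes the stubbed usage possible.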
Get a direct Amazon CloudWatch URL to find the current rule processing job log
The following script returns a CloudWatch URL. Copy the URL and paste it into a browser to go directly to the rule job log page.
import boto3

# These helpers build a direct link to the CloudWatch log stream of each rule processing job
def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
    """Helper function to get the rule job name"""
    return "{}-{}-{}".format(
        training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
    )


def _get_cw_url_for_rule_job(rule_job_name, region):
    return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(region, region, rule_job_name)


def get_rule_jobs_cw_urls(estimator):
    region = boto3.Session().region_name
    description = estimator.latest_training_job.describe()
    training_job_name = description["TrainingJobName"]
    rule_eval_statuses = description["DebugRuleEvaluationStatuses"]

    result = {}
    for status in rule_eval_statuses:
        # Skip rules that do not (yet) have a rule evaluation job ARN
        if status.get("RuleEvaluationJobArn") is not None:
            rule_job_name = _get_rule_job_name(training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"])
            result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(rule_job_name, region)
    return result


print(
    "The direct CloudWatch URL to the current rule job:",
    get_rule_jobs_cw_urls(estimator)[estimator.latest_training_job.rule_job_summary()[0]["RuleConfigurationName"]]
)
Conclusion
This notebook showed how to use the Debugger StalledTrainingRule built-in rule to take action, in this case stopping the training job, when the rule evaluation status changes. For more information about Debugger, see the Amazon SageMaker Debugger Developer Guide and the smdebug GitHub documentation.