@@ -766,25 +766,27 @@ def create_firewall_rule(args, gcloud_compute, network_name, rule_name):
766766 Raises:
767767 subprocess.CalledProcessError: If the `gcloud` command fails
768768 """
769+ firewall_args = get_firewall_args (args , network_name )
769770 if utils .print_info_messages (args ):
770771 print ('Creating the firewall rule {0}' .format (rule_name ))
771772 create_cmd = [
772773 'firewall-rules' , 'create' , rule_name ,
773774 '--allow' , 'tcp:22' ,
774775 '--network' , network_name ,
775776 '--description' , _DATALAB_FIREWALL_RULE_DESCRIPTION ]
776- utils .call_gcloud_quietly (args , gcloud_compute , create_cmd )
777+ utils .call_gcloud_quietly (firewall_args , gcloud_compute , create_cmd )
777778 return
778779
779780
780781def has_unexpected_firewall_rules (args , gcloud_compute , network_name ):
781- rule_name = _DATALAB_FIREWALL_RULE_TEMPLATE .format (network_name )
782+ rule_name = generate_firewall_rule_name (network_name )
783+ firewall_args = get_firewall_args (args , network_name )
782784 list_cmd = [
783785 'firewall-rules' , 'list' ,
784786 '--filter' , 'network~.^*{0}$' .format (network_name ),
785787 '--format' , 'value(name)' ]
786788 with tempfile .TemporaryFile () as tf :
787- gcloud_compute (args , list_cmd , stdout = tf )
789+ gcloud_compute (firewall_args , list_cmd , stdout = tf )
788790 tf .seek (0 )
789791 matching_rules = tf .read ().decode ('utf-8' ).strip ()
790792 if matching_rules and (matching_rules != rule_name ):
@@ -813,17 +815,39 @@ def ensure_firewall_rule_exists(args, gcloud_compute, network_name):
813815 Raises:
814816 subprocess.CalledProcessError: If the `gcloud` command fails
815817 """
816- rule_name = _DATALAB_FIREWALL_RULE_TEMPLATE .format (network_name )
818+ firewall_args = get_firewall_args (args , network_name )
819+ rule_name = generate_firewall_rule_name (network_name )
817820 get_cmd = [
818821 'firewall-rules' , 'describe' , rule_name , '--format' , 'value(name)' ]
819822 try :
820823 utils .call_gcloud_quietly (
821- args , gcloud_compute , get_cmd , report_errors = False )
824+ firewall_args , gcloud_compute , get_cmd , report_errors = False )
822825 except subprocess .CalledProcessError :
823826 create_firewall_rule (args , gcloud_compute , network_name , rule_name )
824827 return
825828
826829
830+ def generate_firewall_rule_name (network_name ):
831+ """Converts network name to a valid rule name to support shared vpc"""
832+ if "/" in network_name :
833+ return _DATALAB_FIREWALL_RULE_TEMPLATE .format (
834+ network_name .split ("/" )[- 1 ])
835+ else :
836+ return _DATALAB_FIREWALL_RULE_TEMPLATE .format (network_name )
837+
838+
839+ def get_firewall_args (args , network_name ):
840+ """
841+ Shared VPCs firewall rules need to be created in the host project.
842+ This modifies the args to the host project for commands that need it.
843+ """
844+ if "/" in network_name :
845+ project_name = network_name .split ("/" )[1 ]
846+ args .project = project_name
847+
848+ return args
849+
850+
827851def create_disk (args , gcloud_compute , disk_name ):
828852 """Create the user's persistent disk.
829853
0 commit comments