@@ -750,6 +750,84 @@ def __init__(self, x): pass
750750 error = "X.__init__"
751751 )
752752
753+ @collect_cases
754+ def test_good_literal (self ) -> Iterator [Case ]:
755+ yield Case (
756+ stub = r"""
757+ from typing_extensions import Literal
758+
759+ import enum
760+ class Color(enum.Enum):
761+ RED: int
762+
763+ NUM: Literal[1]
764+ CHAR: Literal['a']
765+ FLAG: Literal[True]
766+ NON: Literal[None]
767+ BYT1: Literal[b'abc']
768+ BYT2: Literal[b'\x90']
769+ ENUM: Literal[Color.RED]
770+ """ ,
771+ runtime = r"""
772+ import enum
773+ class Color(enum.Enum):
774+ RED = 3
775+
776+ NUM = 1
777+ CHAR = 'a'
778+ NON = None
779+ FLAG = True
780+ BYT1 = b"abc"
781+ BYT2 = b'\x90'
782+ ENUM = Color.RED
783+ """ ,
784+ error = None ,
785+ )
786+
787+ @collect_cases
788+ def test_bad_literal (self ) -> Iterator [Case ]:
789+ yield Case ("from typing_extensions import Literal" , "" , None ) # dummy case
790+ yield Case (
791+ stub = "INT_FLOAT_MISMATCH: Literal[1]" ,
792+ runtime = "INT_FLOAT_MISMATCH = 1.0" ,
793+ error = "INT_FLOAT_MISMATCH" ,
794+ )
795+ yield Case (
796+ stub = "WRONG_INT: Literal[1]" ,
797+ runtime = "WRONG_INT = 2" ,
798+ error = "WRONG_INT" ,
799+ )
800+ yield Case (
801+ stub = "WRONG_STR: Literal['a']" ,
802+ runtime = "WRONG_STR = 'b'" ,
803+ error = "WRONG_STR" ,
804+ )
805+ yield Case (
806+ stub = "BYTES_STR_MISMATCH: Literal[b'value']" ,
807+ runtime = "BYTES_STR_MISMATCH = 'value'" ,
808+ error = "BYTES_STR_MISMATCH" ,
809+ )
810+ yield Case (
811+ stub = "STR_BYTES_MISMATCH: Literal['value']" ,
812+ runtime = "STR_BYTES_MISMATCH = b'value'" ,
813+ error = "STR_BYTES_MISMATCH" ,
814+ )
815+ yield Case (
816+ stub = "WRONG_BYTES: Literal[b'abc']" ,
817+ runtime = "WRONG_BYTES = b'xyz'" ,
818+ error = "WRONG_BYTES" ,
819+ )
820+ yield Case (
821+ stub = "WRONG_BOOL_1: Literal[True]" ,
822+ runtime = "WRONG_BOOL_1 = False" ,
823+ error = 'WRONG_BOOL_1' ,
824+ )
825+ yield Case (
826+ stub = "WRONG_BOOL_2: Literal[False]" ,
827+ runtime = "WRONG_BOOL_2 = True" ,
828+ error = 'WRONG_BOOL_2' ,
829+ )
830+
753831
754832def remove_color_code (s : str ) -> str :
755833 return re .sub ("\\ x1b.*?m" , "" , s ) # this works!
0 commit comments