27
27
from .encoding import to_bytes
28
28
from .solve import (
29
29
solve , Fail , OracleFunc , ResultType ,
30
- convert_to_bytes , remove_padding )
30
+ convert_to_bytes , remove_padding , add_padding )
31
31
32
32
__all__ = [
33
33
'padding_oracle' ,
34
34
]
35
35
36
36
37
- def padding_oracle (ciphertext : Union [bytes , str ],
37
+ def padding_oracle (payload : Union [bytes , str ],
38
38
block_size : int ,
39
39
oracle : OracleFunc ,
40
40
num_threads : int = 1 ,
41
41
log_level : int = logging .INFO ,
42
42
null_byte : bytes = b' ' ,
43
43
return_raw : bool = False ,
44
+ mode : Union [bool , str ] = 'decrypt' ,
45
+ pad_payload : bool = True
44
46
) -> Union [bytes , List [int ]]:
45
47
'''
46
48
Run padding oracle attack to decrypt ciphertext given a function to check
47
49
wether the ciphertext can be decrypted successfully.
48
50
49
51
Args:
50
- ciphertext (bytes|str) the ciphertext you want to decrypt
52
+ payload (bytes|str) the payload you want to encrypt/ decrypt
51
53
block_size (int) block size (the ciphertext length should be
52
54
multiple of this)
53
55
oracle (function) a function: oracle(ciphertext: bytes) -> bool
@@ -58,33 +60,49 @@ def padding_oracle(ciphertext: Union[bytes, str],
58
60
set (default: None)
59
61
return_raw (bool) do not convert plaintext into bytes and
60
62
unpad (default: False)
63
+ mode (str) encrypt the payload (defaut: 'decrypt')
64
+ pad_payload (bool) PKCS#7 pad the supplied payload before
65
+ encryption (default: True)
66
+
61
67
62
68
Returns:
63
- plaintext (bytes|List[int]) the decrypted plaintext
69
+ result (bytes|List[int]) the processed payload
64
70
'''
65
71
66
72
# Check args
67
73
if not callable (oracle ):
68
74
raise TypeError ('the oracle function should be callable' )
69
- if not isinstance (ciphertext , (bytes , str )):
70
- raise TypeError ('ciphertext should have type bytes' )
75
+ if not isinstance (payload , (bytes , str )):
76
+ raise TypeError ('payload should have type bytes' )
71
77
if not isinstance (block_size , int ):
72
78
raise TypeError ('block_size should have type int' )
73
- if not len (ciphertext ) % block_size == 0 :
74
- raise ValueError ('ciphertext length should be multiple of block size' )
75
79
if not 1 <= num_threads <= 1000 :
76
80
raise ValueError ('num_threads should be in [1, 1000]' )
77
81
if not isinstance (null_byte , (bytes , str )):
78
82
raise TypeError ('expect null with type bytes or str' )
79
83
if not len (null_byte ) == 1 :
80
84
raise ValueError ('null byte should have length of 1' )
81
-
85
+ if not isinstance (mode , str ):
86
+ raise TypeError ('expect mode with type str' )
87
+ if isinstance (mode , str ) and mode not in ('encrypt' , 'decrypt' ):
88
+ raise ValueError ('mode must be either encrypt or decrypt' )
89
+ if (mode == 'decrypt' ) and not (len (payload ) % block_size == 0 ):
90
+ raise ValueError ('for decryption payload length should be multiple of block size' )
82
91
logger = get_logger ()
83
92
logger .setLevel (log_level )
84
93
85
- ciphertext = to_bytes (ciphertext )
94
+ payload = to_bytes (payload )
86
95
null_byte = to_bytes (null_byte )
87
96
97
+ # Does the user want the encryption routine
98
+ if (mode == 'encrypt' ):
99
+ return encrypt (payload , block_size , oracle , num_threads , null_byte , pad_payload , logger )
100
+
101
+ # If not continue with decryption as normal
102
+ return decrypt (payload , block_size , oracle , num_threads , null_byte , return_raw , logger )
103
+
104
+
105
+ def decrypt (payload , block_size , oracle , num_threads , null_byte , return_raw , logger ):
88
106
# Wrapper to handle exceptions from the oracle function
89
107
def wrapped_oracle (ciphertext : bytes ):
90
108
try :
@@ -105,7 +123,7 @@ def plaintext_callback(plaintext: bytes):
105
123
plaintext = convert_to_bytes (plaintext , null_byte )
106
124
logger .info (f'plaintext: { plaintext } ' )
107
125
108
- plaintext = solve (ciphertext , block_size , wrapped_oracle , num_threads ,
126
+ plaintext = solve (payload , block_size , wrapped_oracle , num_threads ,
109
127
result_callback , plaintext_callback )
110
128
111
129
if not return_raw :
@@ -115,6 +133,61 @@ def plaintext_callback(plaintext: bytes):
115
133
return plaintext
116
134
117
135
136
+ def encrypt (payload , block_size , oracle , num_threads , null_byte , pad_payload , logger ):
137
+ # Wrapper to handle exceptions from the oracle function
138
+ def wrapped_oracle (ciphertext : bytes ):
139
+ try :
140
+ return oracle (ciphertext )
141
+ except Exception as e :
142
+ logger .error (f'error in oracle with { ciphertext !r} , { e } ' )
143
+ logger .debug ('error details: {}' .format (traceback .format_exc ()))
144
+ return False
145
+
146
+ def result_callback (result : ResultType ):
147
+ if isinstance (result , Fail ):
148
+ if result .is_critical :
149
+ logger .critical (result .message )
150
+ else :
151
+ logger .error (result .message )
152
+
153
+ def plaintext_callback (plaintext : bytes ):
154
+ plaintext = convert_to_bytes (plaintext , null_byte ).strip (null_byte )
155
+ bytes_done = str (len (plaintext )).rjust (len (str (block_size )), ' ' )
156
+ blocks_done = solve_index .rjust (len (block_total ), ' ' )
157
+ printout = "{0}/{1} bytes encrypted in block {2}/{3}" .format (bytes_done , block_size , blocks_done , block_total )
158
+ logger .info (printout )
159
+
160
+ def blocks (data : bytes ):
161
+ return [data [index :(index + block_size )] for index in range (0 , len (data ), block_size )]
162
+
163
+ def bytes_xor (byte_string_1 : bytes , byte_string_2 : bytes ):
164
+ return bytes ([_a ^ _b for _a , _b in zip (byte_string_1 , byte_string_2 )])
165
+
166
+ if pad_payload :
167
+ payload = add_padding (payload , block_size )
168
+
169
+ if len (payload ) % block_size != 0 :
170
+ raise ValueError ('''For encryption payload length must be a multiple of blocksize. Perhaps you meant to
171
+ pad the payload (inbuilt PKCS#7 padding can be enabled by setting pad_payload=True)''' )
172
+
173
+ plaintext_blocks = blocks (payload )
174
+ ciphertext_blocks = [null_byte * block_size for _ in range (len (plaintext_blocks )+ 1 )]
175
+
176
+ solve_index = '1'
177
+ block_total = str (len (plaintext_blocks ))
178
+
179
+ for index in range (len (plaintext_blocks )- 1 , - 1 , - 1 ):
180
+ plaintext = solve (b'\x00 ' * block_size + ciphertext_blocks [index + 1 ], block_size , wrapped_oracle ,
181
+ num_threads , result_callback , plaintext_callback )
182
+ ciphertext_blocks [index ] = bytes_xor (plaintext_blocks [index ], plaintext )
183
+ solve_index = str (int (solve_index )+ 1 )
184
+
185
+ ciphertext = b'' .join (ciphertext_blocks )
186
+ logger .info (f"forged ciphertext: { ciphertext } " )
187
+
188
+ return ciphertext
189
+
190
+
118
191
def get_logger ():
119
192
logger = logging .getLogger ('padding_oracle' )
120
193
formatter = logging .Formatter ('[%(asctime)s][%(levelname)s] %(message)s' )
0 commit comments