Smelter - TAMUctf - Walkthrough
Break a weak RSA cookie signature scheme and retrieve the flag.
This is a write-up for another crypto challenge from TAMUctf. The challenge description is quite informative:
I’ve added 2048 bits of security, so there’s no way you can forge anything in this smelter.
You are given some functions that use RSA. For simplicity, I will only show the important sections of the code.
First, the check_session() function: it checks whether a cookie is present and, if not, sets the default one. If there is a cookie, it checks that the cookie is well-formed and that its signature is valid for the username.
```python
from base64 import b64decode
from functools import wraps

from flask import redirect, request, session, url_for

# DEFAULT_SESSION, decode_session() and verify() are defined elsewhere in the challenge source.


def check_session(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # reset cookie response
        response = redirect(url_for("index"))
        response.set_cookie("smelter-session", DEFAULT_SESSION)
        _session = request.cookies.get("smelter-session")
        if not _session:
            return response
        data = decode_session(_session)
        if not data:
            return response
        username, signature = data.get("username"), data.get("signature")
        if not username or not signature:
            return response
        if not verify(username.encode(), b64decode(signature)):
            return response
        session["username"] = data["username"]
        return f(*args, **kwargs)
    return decorated_function
```
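The excerpt does not show decode_session(). Judging from the cookie that the exploit at the end of this write-up builds, the cookie is base64-encoded JSON holding a username and a base64-encoded signature. Here is a minimal sketch of that round trip, assuming exactly this format; encode_session_sketch and decode_session_sketch are hypothetical names, not the challenge's functions:

```python
import json
from base64 import b64decode, b64encode


def encode_session_sketch(username: str, signature: bytes) -> str:
    # Hypothetical: base64(JSON), with the signature itself base64-encoded,
    # mirroring how the exploit builds its cookie.
    data = {"username": username, "signature": b64encode(signature).decode()}
    return b64encode(json.dumps(data).encode()).decode()


def decode_session_sketch(cookie: str):
    # Hypothetical stand-in for decode_session(): parsed JSON, or None
    # if the cookie is not valid base64 or not valid JSON.
    try:
        return json.loads(b64decode(cookie))
    except ValueError:  # covers binascii.Error and json.JSONDecodeError
        return None


cookie = encode_session_sketch("admin", b"\x01\x02")
data = decode_session_sketch(cookie)
print(data["username"], b64decode(data["signature"]))
```

Anything that fails to decode yields None, which check_session() treats the same way as a missing cookie.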
To get the flag, we need a valid session for the username admin:
```python
if username == "admin":
    return render_template("index.html", username=username, public_key=public_key(), flag=flag)
```
The next important function is verify(), which is called after decode_session() (that helper just base64-decodes the cookie into an object). Both sign() and verify() live in the following RSA code:
```python
from Crypto.Util.number import getPrime, bytes_to_long, long_to_bytes
from Crypto.PublicKey import RSA
from hashlib import sha256
from src.util import encode, decode

e = 3
p, q = getPrime(1024), getPrime(1024)
while True:
    n = p * q
    phi = (p - 1) * (q - 1)
    try:
        d = pow(e, -1, phi)
        break
    except:
        p, q = getPrime(1024), getPrime(1024)

privkey = RSA.construct((n, e, d))
pubkey = privkey.publickey()


def public_key() -> str:
    return pubkey.export_key(format='PEM').decode()


def sign(message: bytes) -> bytes:
    h = sha256(message).digest()
    h = encode(h)
    return long_to_bytes(pow(bytes_to_long(h), d, n))


def verify(message: bytes, signature: bytes) -> bool:
    h = sha256(message).digest()
    signature = bytes_to_long(signature)
    signature = pow(signature, e, n)
    signature = long_to_bytes(signature, 256)
    print(signature)
    signed_h = decode(signature)
    return h == signed_h
```
The RSA key is built from two 1024-bit primes, so n has 2048 bits. The key generation itself looks fine, but note the low public exponent e = 3, which will come in handy.
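To see why the small exponent matters: verify() computes s^3 mod n, and whenever s^3 < n the reduction does nothing, so the operation is an ordinary integer cube that anyone can compute or invert without d. A small sketch with a made-up modulus of the same size (not the challenge's key); integer_cube_root is a helper written only for this example:

```python
def integer_cube_root(x: int) -> int:
    """Largest r with r**3 <= x, using integer Newton iteration."""
    if x < 2:
        return x
    r = 1 << ((x.bit_length() + 2) // 3)  # starting guess above the real root
    while True:
        s = (2 * r + x // (r * r)) // 3
        if s >= r:
            return r
        r = s


# Made-up 2048-bit modulus for illustration only, NOT the challenge key.
n = (1 << 2047) + 1
e = 3
s = 3 ** 400 + 12345  # about 634 bits, so s**3 stays far below n

cube = pow(s, e, n)
print(cube == s ** 3)                # True: the "mod n" never kicks in
print(integer_cube_root(cube) == s)  # True: s is recovered without d
```

The attack relies on the same fact: we never need d, only an s whose plain cube contains the right bytes.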
The verify() function converts the signature bytes to an integer, raises it to the power e modulo n, and turns the result back into a 256-byte string. That string is passed to decode(), and the extracted digest is compared with the SHA-256 hash of the message. decode() is the interesting part: it looks hand-rolled, and hand-rolled parsing of signature padding is dangerous.
```python
from pyasn1.codec.der import decoder  # import assumed; the excerpt only shows `decoder`
from pyasn1.type import univ

# DigestInfo and SHA256_OID are defined elsewhere in the challenge source
# (the exploit below reproduces them).


def decode(data):
    try:
        data = data[3:]
        data = data[data.index(b"\x00") + 1:]
        obj, _ = decoder.decode(data, asn1Spec=DigestInfo())
        assert obj['digestAlgorithm']['algorithm'].asTuple() == SHA256_OID
        assert isinstance(obj['digestAlgorithm']['parameters'], univ.Null)
        assert isinstance(obj['digest'], univ.OctetString)
        return obj['digest'].asOctets()
    except:
        # clearly you don't have a real signature
        return None
```
decode() first drops the first three bytes, then cuts away everything up to and including the next zero byte, and parses only what remains as an ASN.1 (DER) DigestInfo structure to extract the hash. The encode() function shows what the discarded part is supposed to contain:
```python
from pyasn1.codec.der import encoder

# DigestInfo and SHA256_ALGORITHM are defined elsewhere in the challenge source.


def encode(hash):
    obj = DigestInfo()
    obj['digestAlgorithm'] = SHA256_ALGORITHM
    obj['digest'] = hash
    enc = encoder.encode(obj)
    padding_len = 256 - len(enc) - 3
    if padding_len < 0:
        raise Exception("hash too long")
    return b"\x00\x01" + b"\xff" * (256 - len(enc) - 3) + b"\x00" + enc
```
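The first part of the encoded block is meant to be PKCS#1 v1.5 padding: 0x00 0x01, a run of 0xff bytes, and a 0x00 separator. decode() never checks it, though. It skips the first three bytes and then everything up to the first 0x00, so the "padding" can be anything we want, as long as it contains no 0x00 byte.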
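The slicing can be checked in isolation. The sketch below copies only the byte-handling part of decode() into a hypothetical strip_padding() (the ASN.1 parsing is left out, so pyasn1 is not needed) and feeds it an honest block, a block with arbitrary non-zero filler, and one with a stray 0x00. The DER bytes are a placeholder, not a real DigestInfo.

```python
def strip_padding(block: bytes) -> bytes:
    # Same slicing as the challenge's decode(), minus the ASN.1 parsing:
    # drop the first three bytes, then everything up to the first 0x00.
    data = block[3:]
    return data[data.index(b"\x00") + 1:]


der = b"\x30\x31" + b"\xaa" * 49  # placeholder for the 51-byte DER DigestInfo

honest = b"\x00\x01" + b"\xff" * (256 - len(der) - 3) + b"\x00" + der
forged = b"\x13\x37\x00" + bytes(range(1, 202)) + b"\x00" + der
broken = b"\x00\x01\xff" + b"\x05\x00" + b"\xff" * 199 + b"\x00" + der

print(strip_padding(honest) == der)  # True
print(strip_padding(forged) == der)  # True: arbitrary non-zero filler is accepted
print(strip_padding(broken) == der)  # False: a stray 0x00 cuts too early
```

Note that the forged block even has 0x00 as its third byte; the first three bytes are never inspected at all.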
How can we use these observations?
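Because e = 3, verify() only computes s^3 mod n, and as long as s^3 < n that reduction does nothing. Since n has 2048 bits, we can look for an integer s whose cube ends with exactly the bytes decode() wants to see: a 0x00 separator followed by the DER-encoded DigestInfo produced by encoder.encode(obj). The bytes above that suffix are whatever the cube happens to produce, so two conditions must hold. First, s^3 must be smaller than n. Second, apart from the first three bytes that decode() skips, the 256-byte encoding of s^3 must not contain a 0x00 byte before our separator. In other words, s^3 has to be almost as long as n, otherwise its leading zero bytes end up inside the part that decode() scans.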
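The two conditions can be written down as a small predicate. This is a sketch, not part of the original exploit: forgery_is_acceptable and its size parameter are hypothetical, and it uses int.to_bytes, which gives the same result as long_to_bytes(value, 256) in verify() for values below 2^2048.

```python
def forgery_is_acceptable(s: int, n: int, tail: bytes, size: int = 256) -> bool:
    """Check both conditions for a forged signature s (hypothetical helper).

    tail is what must end the cube: the 0x00 separator followed by the
    DER-encoded DigestInfo. size is the block length verify() pads to
    (256 bytes for a 2048-bit n).
    """
    cube = s ** 3
    if cube >= n:  # a reduction mod n would scramble the bytes
        return False
    block = cube.to_bytes(size, "big")
    if not block.endswith(tail):
        return False
    # decode() skips block[:3]; after that, the first 0x00 must be our separator
    return b"\x00" not in block[3:size - len(tail)]


# Toy example with a 6-byte block: 64**3 == 0x040000
print(forgery_is_acceptable(64, 10**9, b"\x00\x00", size=6))  # True
print(forgery_is_acceptable(64, 10**9, b"\x00\x00", size=7))  # False: extra leading zero
print(forgery_is_acceptable(64, 1000, b"\x00\x00", size=6))   # False: s**3 >= n
```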
To implement this, we first compute the encoding for the admin username in the format the script uses, then search for a number whose cube meets these criteria. I copied the number-theory part for finding the root from existing work; if you'd like to read up on the topic, search for cube roots via Hensel's lemma.
```python
import json
from base64 import b64encode
from hashlib import sha256

from Crypto.Util.number import long_to_bytes
from pyasn1.codec.der import decoder, encoder
from pyasn1.type import namedtype, univ


class AlgorithmIdentifier(univ.Sequence):
    componentType = namedtype.NamedTypes(
        namedtype.NamedType('algorithm', univ.ObjectIdentifier()),
        namedtype.NamedType('parameters', univ.Null())
    )


class DigestInfo(univ.Sequence):
    componentType = namedtype.NamedTypes(
        namedtype.NamedType('digestAlgorithm', AlgorithmIdentifier()),
        namedtype.NamedType('digest', univ.OctetString())
    )


# we only use sha256 so no need to worry about other algorithms
SHA256_OID = (2, 16, 840, 1, 101, 3, 4, 2, 1)
SHA256_ALGORITHM = AlgorithmIdentifier()
SHA256_ALGORITHM['algorithm'] = SHA256_OID
SHA256_ALGORITHM['parameters'] = univ.Null()


def encode(hash):
    # Unlike the server's encode(), return only the bytes decode() looks at:
    # the 0x00 separator followed by the DER-encoded DigestInfo.
    obj = DigestInfo()
    obj['digestAlgorithm'] = SHA256_ALGORITHM
    obj['digest'] = hash
    return b"\x00" + encoder.encode(obj)


def sign(message: bytes) -> bytes:
    # Not a real signature: the byte string that must end up at the bottom of s^3.
    return encode(sha256(message).digest())


def cube_root_mod_2k(c, bits):
    """Hensel-lift the unique cube root of an odd c modulo 2**bits."""
    x = c % 8  # every odd x satisfies x**3 == x (mod 8)
    precision = 3
    while precision < bits:
        precision = min(2 * precision, bits)
        mod = 1 << precision
        # Newton step; 3*x*x is odd, so it is invertible modulo 2**precision
        x = (x - (x ** 3 - c) * pow(3 * x * x, -1, mod)) % mod
    return x % (1 << bits)


def find_cube_root_with_suffix(suffix_bytes):
    """Return x such that x**3 ends with suffix_bytes (big-endian), or None."""
    bits = 8 * len(suffix_bytes)
    C = int.from_bytes(suffix_bytes, byteorder='big')
    if C == 0:
        return 0
    v = (C & -C).bit_length() - 1  # number of trailing zero bits of C
    if v % 3 != 0:
        return None  # x = 2^r * z (z odd) gives x^3 exactly 3r trailing zero bits
    z = cube_root_mod_2k(C >> v, bits - v)
    x = z << (v // 3)
    # Ensure x^3 really ends with the requested bytes
    assert pow(x, 3, 1 << bits) == C
    return x


# public modulus of the challenge instance
n = 19127756420097310748445892701499128748802811170394545389894177686010747255446524025209222591763566871458062034924855333220863403640999049071598724090912794271529705833396902663514178629181602063192367835048623464092217195245880040478169420526326907683287975987599829494476148273374575347271069072444176409368157687551295200366781356061118771116864015191263073471755543677323807202265423577097247321828015464869034718847861674633781316574328871975526483891957817409481281089186165564002205714207647105647702204323270996932328410444422984930202645110182771670631186393527961095259316715737096129427647195936911415344769

tail = sign(b"admin")  # b"\x00" + DER(DigestInfo(sha256(b"admin")))
add = 34  # number of non-zero filler bytes placed in front of the tail

for filler in range(1, 256):
    x = find_cube_root_with_suffix(bytes([filler]) * add + tail)
    if x is None:
        break  # depends only on the tail's trailing bits, so no filler can help
    cube = x ** 3
    if cube >= n:
        continue  # a reduction mod n would destroy the suffix
    block = long_to_bytes(cube, 256)
    # decode() skips block[:3]; the first 0x00 after that must be our separator
    if b"\x00" not in block[3:256 - len(tail)]:
        break
else:
    x = None
assert x is not None, "no suitable filler byte found, try another value of add"
signature = x

# Replay the server's verify()/decode() locally before sending the cookie
block = long_to_bytes(pow(signature, 3, n), 256)[3:]
obj, _ = decoder.decode(block[block.index(b"\x00") + 1:], asn1Spec=DigestInfo())
assert obj['digest'].asOctets() == sha256(b"admin").digest()

data = {
    "username": "admin",
    "signature": b64encode(long_to_bytes(signature)).decode()
}
cookie = b64encode(json.dumps(data).encode()).decode()
print(cookie)
```
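For reference, here is the Hensel's-lemma step behind the cube-root search, written out for a power-of-two modulus. Here c is the odd value the cube must match, x_i is the current root and t_i its precision in bits; the notation is ours, not from the original source.

```latex
\begin{aligned}
&x^3 \equiv x \pmod{8} \text{ for odd } x
  \;\Rightarrow\; x_0 = c \bmod 8,\quad t_0 = 3,\\
&x_i^3 \equiv c \pmod{2^{t_i}}
  \;\Rightarrow\; x_{i+1} = x_i - (x_i^3 - c)\,(3x_i^2)^{-1} \bmod 2^{2t_i},
  \quad x_{i+1}^3 \equiv c \pmod{2^{2t_i}}.
\end{aligned}
```

Since 3x_i^2 is odd, the inverse always exists and the precision doubles at every step. For a k-byte suffix whose value C is even, write C = 2^{3r} c with c odd, lift a root z of z^3 ≡ c modulo 2^{8k-3r}, and take x = 2^r z. If the number of trailing zero bits of C is not a multiple of 3, no cube ends with those bytes.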