Skip to content

Commit de63ce4

Browse files
committed
v4 support pre-alpha
1 parent 7a90811 commit de63ce4

File tree

4 files changed

+364
-0
lines changed

4 files changed

+364
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,4 @@ ENV/
8181

8282
# mkdocs documentation
8383
/site
84+
/.vs/ProjectSettings.json

.vs/slnx.sqlite

96 KB
Binary file not shown.

uniswap/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from . import exceptions
22
from .uniswap import Uniswap, _str_to_addr
3+
from .uniswap4 import Uniswap4
34
from .cli import main

uniswap/uniswap4.py

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
import os
2+
import time
3+
import logging
4+
import functools
5+
from typing import List, Any, Optional, Union, Tuple, Dict
6+
7+
from web3 import Web3
8+
from web3.eth import Contract
9+
from web3.contract import ContractFunction
10+
from web3.exceptions import BadFunctionCallOutput, ContractLogicError
11+
from web3.types import (
12+
TxParams,
13+
Wei,
14+
Address,
15+
ChecksumAddress,
16+
Nonce,
17+
HexBytes,
18+
)
19+
20+
from .types import AddressLike
21+
from .token import ERC20Token
22+
from .tokens import tokens, tokens_rinkeby
23+
from .exceptions import InvalidToken, InsufficientBalance
24+
from .util import (
25+
_str_to_addr,
26+
_addr_to_str,
27+
_validate_address,
28+
_load_contract,
29+
_load_contract_erc20,
30+
is_same_address,
31+
)
32+
from .decorators import supports, check_approval
33+
from .constants import (
34+
_netid_to_name,
35+
_poolmanager_contract_addresses,
36+
ETH_ADDRESS,
37+
)
38+
39+
logger = logging.getLogger(__name__)
40+
41+
42+
class Uniswap4:
43+
"""
44+
Wrapper around Uniswap v4 contracts.
45+
"""
46+
47+
def __init__(
48+
self,
49+
address: Union[AddressLike, str, None],
50+
private_key: Optional[str],
51+
provider: str = None,
52+
web3: Web3 = None,
53+
default_slippage: float = 0.01,
54+
poolmanager_contract_addr: str = None,
55+
) -> None:
56+
"""
57+
:param address: The public address of the ETH wallet to use.
58+
:param private_key: The private key of the ETH wallet to use.
59+
:param provider: Can be optionally set to a Web3 provider URI. If none set, will fall back to the PROVIDER environment variable, or web3 if set.
60+
:param web3: Can be optionally set to a custom Web3 instance.
61+
:param poolmanager_contract_addr: Can be optionally set to override the address of the PoolManager contract.
62+
"""
63+
self.address: AddressLike = _str_to_addr(
64+
address or "0x0000000000000000000000000000000000000000"
65+
)
66+
self.private_key = (
67+
private_key
68+
or "0x0000000000000000000000000000000000000000000000000000000000000000"
69+
)
70+
71+
if web3:
72+
self.w3 = web3
73+
else:
74+
# Initialize web3. Extra provider for testing.
75+
self.provider = provider or os.environ["PROVIDER"]
76+
self.w3 = Web3(
77+
Web3.HTTPProvider(self.provider, request_kwargs={"timeout": 60})
78+
)
79+
80+
netid = int(self.w3.net.version)
81+
if netid in _netid_to_name:
82+
self.network = _netid_to_name[netid]
83+
else:
84+
raise Exception(f"Unknown netid: {netid}")
85+
logger.info(f"Using {self.w3} ('{self.network}')")
86+
87+
self.last_nonce: Nonce = self.w3.eth.get_transaction_count(self.address)
88+
89+
if poolmanager_contract_addr is None:
90+
poolmanager_contract_addr = _poolmanager_contract_addresses[self.network]
91+
92+
self.poolmanager_contract = _load_contract(
93+
self.w3,
94+
abi_name="uniswap-v4/poolmanager",
95+
address=_str_to_addr(poolmanager_contract_addr),
96+
)
97+
98+
if hasattr(self, "poolmanager_contract"):
99+
logger.info(f"Using factory contract: {self.poolmanager_contract}")
100+
101+
# ------ Market --------------------------------------------------------------------
102+
103+
def get_price(
104+
self,
105+
token0: AddressLike, # input token
106+
token1: AddressLike, # output token
107+
qty: int,
108+
fee: int,
109+
route: Optional[List[AddressLike]] = None,
110+
zero_to_one: bool = true,
111+
) -> int:
112+
"""
113+
:if `zero_to_one` is true: given `qty` amount of the input `token0`, returns the maximum output amount of output `token1`.
114+
:if `zero_to_one` is false: returns the minimum amount of `token0` required to buy `qty` amount of `token1`.
115+
"""
116+
117+
# WIP
118+
119+
return 0
120+
121+
# ------ Make Trade ----------------------------------------------------------------
122+
def make_trade(
123+
self,
124+
currency0: ERC20Token,
125+
currency1: ERC20Token,
126+
qty: Union[int, Wei],
127+
fee: int,
128+
tick_spacing: int,
129+
sqrt_price_limit_x96: int = 0,
130+
zero_for_one: bool = true,
131+
hooks: AddressLike = ETH,
132+
) -> HexBytes:
133+
"""
134+
:Swap against the given pool
135+
:
136+
:`currency0`:The lower currency of the pool, sorted numerically
137+
:`currency1`:The higher currency of the pool, sorted numerically
138+
:`fee`: The pool swap fee, capped at 1_000_000. The upper 4 bits determine if the hook sets any fees.
139+
:`tickSpacing`: Ticks that involve positions must be a multiple of tick spacing
140+
:`hooks`: The hooks of the pool
141+
:if `zero_for_one` is true: make a trade by defining the qty of the input token.
142+
:if `zero_for_one` is false: make a trade by defining the qty of the output token.
143+
"""
144+
if currency0 == currency1:
145+
raise ValueError
146+
147+
pool_key = {
148+
"currency0": currency0.address,
149+
"currency1": currency1.address,
150+
"fee": fee,
151+
"tickSpacing": tick_spacing,
152+
"hooks": hooks,
153+
}
154+
155+
swap_params = {
156+
"zeroForOne": zero_for_one,
157+
"amountSpecified": qty,
158+
"sqrtPriceLimitX96": sqrt_price_limit_x96,
159+
}
160+
161+
return self._build_and_send_tx(
162+
self.router.functions.swap(
163+
{
164+
"key": pool_key,
165+
"params": swap_params,
166+
}
167+
),
168+
self._get_tx_params(value=qty),
169+
)
170+
171+
# ------ Wallet balance ------------------------------------------------------------
172+
def get_eth_balance(self) -> Wei:
173+
"""Get the balance of ETH for your address."""
174+
return self.w3.eth.get_balance(self.address)
175+
176+
def get_token_balance(self, token: AddressLike) -> int:
177+
"""Get the balance of a token for your address."""
178+
_validate_address(token)
179+
if _addr_to_str(token) == ETH_ADDRESS:
180+
return self.get_eth_balance()
181+
erc20 = _load_contract_erc20(self.w3, token)
182+
balance: int = erc20.functions.balanceOf(self.address).call()
183+
return balance
184+
185+
# ------ Liquidity -----------------------------------------------------------------
186+
def initialize(
187+
self,
188+
currency0: ERC20Token,
189+
currency1: ERC20Token,
190+
qty: Union[int, Wei],
191+
fee: int,
192+
tick_spacing: int,
193+
hooks: AddressLike,
194+
sqrt_price_limit_x96: int,
195+
) -> HexBytes:
196+
"""
197+
:Initialize the state for a given pool ID
198+
:
199+
:`currency0`:The lower currency of the pool, sorted numerically
200+
:`currency1`:The higher currency of the pool, sorted numerically
201+
:`fee`: The pool swap fee, capped at 1_000_000. The upper 4 bits determine if the hook sets any fees.
202+
:`tickSpacing`: Ticks that involve positions must be a multiple of tick spacing
203+
:`hooks`: The hooks of the pool
204+
"""
205+
if currency0 == currency1:
206+
raise ValueError
207+
208+
pool_key = {
209+
"currency0": currency0.address,
210+
"currency1": currency1.address,
211+
"fee": fee,
212+
"tickSpacing": tick_spacing,
213+
"hooks": hooks,
214+
}
215+
216+
return self._build_and_send_tx(
217+
self.router.functions.initialize(
218+
{
219+
"key": pool_key,
220+
"sqrtPriceX96": sqrt_price_limit_x96,
221+
}
222+
),
223+
self._get_tx_params(value=qty),
224+
)
225+
226+
def modify_position(
227+
self,
228+
currency0: ERC20Token,
229+
currency1: ERC20Token,
230+
qty: Union[int, Wei],
231+
fee: int,
232+
tick_spacing: int,
233+
tick_upper: int,
234+
tick_lower: int,
235+
hooks: AddressLike,
236+
) -> HexBytes:
237+
if currency0 == currency1:
238+
raise ValueError
239+
240+
pool_key = {
241+
"currency0": currency0.address,
242+
"currency1": currency1.address,
243+
"fee": fee,
244+
"tickSpacing": tick_spacing,
245+
"hooks": hooks,
246+
}
247+
248+
modify_position_params = {
249+
"tickLower": tick_lower,
250+
"tickUpper": tick_upper,
251+
"liquidityDelta": qty,
252+
}
253+
254+
return self._build_and_send_tx(
255+
self.router.functions.modifyPosition(
256+
{
257+
"key": pool_key,
258+
"params": modify_position_params,
259+
}
260+
),
261+
self._get_tx_params(value=qty),
262+
)
263+
264+
# ------ Approval Utils ------------------------------------------------------------
265+
def approve(self, token: AddressLike, max_approval: Optional[int] = None) -> None:
266+
"""Give an exchange/router max approval of a token."""
267+
max_approval = self.max_approval_int if not max_approval else max_approval
268+
contract_addr = (
269+
self._exchange_address_from_token(token)
270+
if self.version == 1
271+
else self.router_address
272+
)
273+
function = _load_contract_erc20(self.w3, token).functions.approve(
274+
contract_addr, max_approval
275+
)
276+
logger.warning(f"Approving {_addr_to_str(token)}...")
277+
tx = self._build_and_send_tx(function)
278+
self.w3.eth.wait_for_transaction_receipt(tx, timeout=6000)
279+
280+
# Add extra sleep to let tx propogate correctly
281+
time.sleep(1)
282+
283+
# ------ Tx Utils ------------------------------------------------------------------
284+
def _deadline(self) -> int:
285+
"""Get a predefined deadline. 10min by default (same as the Uniswap SDK)."""
286+
return int(time.time()) + 10 * 60
287+
288+
def _build_and_send_tx(
289+
self, function: ContractFunction, tx_params: Optional[TxParams] = None
290+
) -> HexBytes:
291+
"""Build and send a transaction."""
292+
if not tx_params:
293+
tx_params = self._get_tx_params()
294+
transaction = function.buildTransaction(tx_params)
295+
# Uniswap3 uses 20% margin for transactions
296+
transaction["gas"] = Wei(int(self.w3.eth.estimate_gas(transaction) * 1.2))
297+
signed_txn = self.w3.eth.account.sign_transaction(
298+
transaction, private_key=self.private_key
299+
)
300+
# TODO: This needs to get more complicated if we want to support replacing a transaction
301+
# FIXME: This does not play nice if transactions are sent from other places using the same wallet.
302+
try:
303+
return self.w3.eth.send_raw_transaction(signed_txn.rawTransaction)
304+
finally:
305+
logger.debug(f"nonce: {tx_params['nonce']}")
306+
self.last_nonce = Nonce(tx_params["nonce"] + 1)
307+
308+
def _get_tx_params(self, value: Wei = Wei(0)) -> TxParams:
309+
"""Get generic transaction parameters."""
310+
return {
311+
"from": _addr_to_str(self.address),
312+
"value": value,
313+
"nonce": max(
314+
self.last_nonce, self.w3.eth.get_transaction_count(self.address)
315+
),
316+
}
317+
318+
# ------ Helpers ------------------------------------------------------------
319+
320+
def get_token(self, address: AddressLike, abi_name: str = "erc20") -> ERC20Token:
321+
"""
322+
Retrieves metadata from the ERC20 contract of a given token, like its name, symbol, and decimals.
323+
"""
324+
# FIXME: This function should always return the same output for the same input
325+
# and would therefore benefit from caching
326+
if address == ETH_ADDRESS:
327+
return ERC20Token("ETH", ETH_ADDRESS, "Ether", 18)
328+
token_contract = _load_contract(self.w3, abi_name, address=address)
329+
try:
330+
_name = token_contract.functions.name().call()
331+
_symbol = token_contract.functions.symbol().call()
332+
decimals = token_contract.functions.decimals().call()
333+
except Exception as e:
334+
logger.warning(
335+
f"Exception occurred while trying to get token {_addr_to_str(address)}: {e}"
336+
)
337+
raise InvalidToken(address)
338+
try:
339+
name = _name.decode()
340+
except:
341+
name = _name
342+
try:
343+
symbol = _symbol.decode()
344+
except:
345+
symbol = _symbol
346+
return ERC20Token(symbol, address, name, decimals)
347+
348+
# ------ Test utilities ------------------------------------------------------------
349+
350+
def _get_token_addresses(self) -> Dict[str, ChecksumAddress]:
351+
"""
352+
Returns a dict with addresses for tokens for the current net.
353+
Used in testing.
354+
"""
355+
netid = int(self.w3.net.version)
356+
netname = _netid_to_name[netid]
357+
if netname == "mainnet":
358+
return tokens
359+
elif netname == "rinkeby":
360+
return tokens_rinkeby
361+
else:
362+
raise Exception(f"Unknown net '{netname}'")

0 commit comments

Comments
 (0)