|
| 1 | +import os |
| 2 | +import time |
| 3 | +import logging |
| 4 | +import functools |
| 5 | +from typing import List, Any, Optional, Union, Tuple, Dict |
| 6 | + |
| 7 | +from web3 import Web3 |
| 8 | +from web3.eth import Contract |
| 9 | +from web3.contract import ContractFunction |
| 10 | +from web3.exceptions import BadFunctionCallOutput, ContractLogicError |
| 11 | +from web3.types import ( |
| 12 | + TxParams, |
| 13 | + Wei, |
| 14 | + Address, |
| 15 | + ChecksumAddress, |
| 16 | + Nonce, |
| 17 | + HexBytes, |
| 18 | +) |
| 19 | + |
| 20 | +from .types import AddressLike |
| 21 | +from .token import ERC20Token |
| 22 | +from .tokens import tokens, tokens_rinkeby |
| 23 | +from .exceptions import InvalidToken, InsufficientBalance |
| 24 | +from .util import ( |
| 25 | + _str_to_addr, |
| 26 | + _addr_to_str, |
| 27 | + _validate_address, |
| 28 | + _load_contract, |
| 29 | + _load_contract_erc20, |
| 30 | + is_same_address, |
| 31 | +) |
| 32 | +from .decorators import supports, check_approval |
| 33 | +from .constants import ( |
| 34 | + _netid_to_name, |
| 35 | + _poolmanager_contract_addresses, |
| 36 | + ETH_ADDRESS, |
| 37 | +) |
| 38 | + |
| 39 | +logger = logging.getLogger(__name__) |
| 40 | + |
| 41 | + |
| 42 | +class Uniswap4: |
| 43 | + """ |
| 44 | + Wrapper around Uniswap v4 contracts. |
| 45 | + """ |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + address: Union[AddressLike, str, None], |
| 50 | + private_key: Optional[str], |
| 51 | + provider: str = None, |
| 52 | + web3: Web3 = None, |
| 53 | + default_slippage: float = 0.01, |
| 54 | + poolmanager_contract_addr: str = None, |
| 55 | + ) -> None: |
| 56 | + """ |
| 57 | + :param address: The public address of the ETH wallet to use. |
| 58 | + :param private_key: The private key of the ETH wallet to use. |
| 59 | + :param provider: Can be optionally set to a Web3 provider URI. If none set, will fall back to the PROVIDER environment variable, or web3 if set. |
| 60 | + :param web3: Can be optionally set to a custom Web3 instance. |
| 61 | + :param poolmanager_contract_addr: Can be optionally set to override the address of the PoolManager contract. |
| 62 | + """ |
| 63 | + self.address: AddressLike = _str_to_addr( |
| 64 | + address or "0x0000000000000000000000000000000000000000" |
| 65 | + ) |
| 66 | + self.private_key = ( |
| 67 | + private_key |
| 68 | + or "0x0000000000000000000000000000000000000000000000000000000000000000" |
| 69 | + ) |
| 70 | + |
| 71 | + if web3: |
| 72 | + self.w3 = web3 |
| 73 | + else: |
| 74 | + # Initialize web3. Extra provider for testing. |
| 75 | + self.provider = provider or os.environ["PROVIDER"] |
| 76 | + self.w3 = Web3( |
| 77 | + Web3.HTTPProvider(self.provider, request_kwargs={"timeout": 60}) |
| 78 | + ) |
| 79 | + |
| 80 | + netid = int(self.w3.net.version) |
| 81 | + if netid in _netid_to_name: |
| 82 | + self.network = _netid_to_name[netid] |
| 83 | + else: |
| 84 | + raise Exception(f"Unknown netid: {netid}") |
| 85 | + logger.info(f"Using {self.w3} ('{self.network}')") |
| 86 | + |
| 87 | + self.last_nonce: Nonce = self.w3.eth.get_transaction_count(self.address) |
| 88 | + |
| 89 | + if poolmanager_contract_addr is None: |
| 90 | + poolmanager_contract_addr = _poolmanager_contract_addresses[self.network] |
| 91 | + |
| 92 | + self.poolmanager_contract = _load_contract( |
| 93 | + self.w3, |
| 94 | + abi_name="uniswap-v4/poolmanager", |
| 95 | + address=_str_to_addr(poolmanager_contract_addr), |
| 96 | + ) |
| 97 | + |
| 98 | + if hasattr(self, "poolmanager_contract"): |
| 99 | + logger.info(f"Using factory contract: {self.poolmanager_contract}") |
| 100 | + |
| 101 | + # ------ Market -------------------------------------------------------------------- |
| 102 | + |
| 103 | + def get_price( |
| 104 | + self, |
| 105 | + token0: AddressLike, # input token |
| 106 | + token1: AddressLike, # output token |
| 107 | + qty: int, |
| 108 | + fee: int, |
| 109 | + route: Optional[List[AddressLike]] = None, |
| 110 | + zero_to_one: bool = true, |
| 111 | + ) -> int: |
| 112 | + """ |
| 113 | + :if `zero_to_one` is true: given `qty` amount of the input `token0`, returns the maximum output amount of output `token1`. |
| 114 | + :if `zero_to_one` is false: returns the minimum amount of `token0` required to buy `qty` amount of `token1`. |
| 115 | + """ |
| 116 | + |
| 117 | + # WIP |
| 118 | + |
| 119 | + return 0 |
| 120 | + |
| 121 | + # ------ Make Trade ---------------------------------------------------------------- |
| 122 | + def make_trade( |
| 123 | + self, |
| 124 | + currency0: ERC20Token, |
| 125 | + currency1: ERC20Token, |
| 126 | + qty: Union[int, Wei], |
| 127 | + fee: int, |
| 128 | + tick_spacing: int, |
| 129 | + sqrt_price_limit_x96: int = 0, |
| 130 | + zero_for_one: bool = true, |
| 131 | + hooks: AddressLike = ETH, |
| 132 | + ) -> HexBytes: |
| 133 | + """ |
| 134 | + :Swap against the given pool |
| 135 | + : |
| 136 | + :`currency0`:The lower currency of the pool, sorted numerically |
| 137 | + :`currency1`:The higher currency of the pool, sorted numerically |
| 138 | + :`fee`: The pool swap fee, capped at 1_000_000. The upper 4 bits determine if the hook sets any fees. |
| 139 | + :`tickSpacing`: Ticks that involve positions must be a multiple of tick spacing |
| 140 | + :`hooks`: The hooks of the pool |
| 141 | + :if `zero_for_one` is true: make a trade by defining the qty of the input token. |
| 142 | + :if `zero_for_one` is false: make a trade by defining the qty of the output token. |
| 143 | + """ |
| 144 | + if currency0 == currency1: |
| 145 | + raise ValueError |
| 146 | + |
| 147 | + pool_key = { |
| 148 | + "currency0": currency0.address, |
| 149 | + "currency1": currency1.address, |
| 150 | + "fee": fee, |
| 151 | + "tickSpacing": tick_spacing, |
| 152 | + "hooks": hooks, |
| 153 | + } |
| 154 | + |
| 155 | + swap_params = { |
| 156 | + "zeroForOne": zero_for_one, |
| 157 | + "amountSpecified": qty, |
| 158 | + "sqrtPriceLimitX96": sqrt_price_limit_x96, |
| 159 | + } |
| 160 | + |
| 161 | + return self._build_and_send_tx( |
| 162 | + self.router.functions.swap( |
| 163 | + { |
| 164 | + "key": pool_key, |
| 165 | + "params": swap_params, |
| 166 | + } |
| 167 | + ), |
| 168 | + self._get_tx_params(value=qty), |
| 169 | + ) |
| 170 | + |
| 171 | + # ------ Wallet balance ------------------------------------------------------------ |
| 172 | + def get_eth_balance(self) -> Wei: |
| 173 | + """Get the balance of ETH for your address.""" |
| 174 | + return self.w3.eth.get_balance(self.address) |
| 175 | + |
| 176 | + def get_token_balance(self, token: AddressLike) -> int: |
| 177 | + """Get the balance of a token for your address.""" |
| 178 | + _validate_address(token) |
| 179 | + if _addr_to_str(token) == ETH_ADDRESS: |
| 180 | + return self.get_eth_balance() |
| 181 | + erc20 = _load_contract_erc20(self.w3, token) |
| 182 | + balance: int = erc20.functions.balanceOf(self.address).call() |
| 183 | + return balance |
| 184 | + |
| 185 | + # ------ Liquidity ----------------------------------------------------------------- |
| 186 | + def initialize( |
| 187 | + self, |
| 188 | + currency0: ERC20Token, |
| 189 | + currency1: ERC20Token, |
| 190 | + qty: Union[int, Wei], |
| 191 | + fee: int, |
| 192 | + tick_spacing: int, |
| 193 | + hooks: AddressLike, |
| 194 | + sqrt_price_limit_x96: int, |
| 195 | + ) -> HexBytes: |
| 196 | + """ |
| 197 | + :Initialize the state for a given pool ID |
| 198 | + : |
| 199 | + :`currency0`:The lower currency of the pool, sorted numerically |
| 200 | + :`currency1`:The higher currency of the pool, sorted numerically |
| 201 | + :`fee`: The pool swap fee, capped at 1_000_000. The upper 4 bits determine if the hook sets any fees. |
| 202 | + :`tickSpacing`: Ticks that involve positions must be a multiple of tick spacing |
| 203 | + :`hooks`: The hooks of the pool |
| 204 | + """ |
| 205 | + if currency0 == currency1: |
| 206 | + raise ValueError |
| 207 | + |
| 208 | + pool_key = { |
| 209 | + "currency0": currency0.address, |
| 210 | + "currency1": currency1.address, |
| 211 | + "fee": fee, |
| 212 | + "tickSpacing": tick_spacing, |
| 213 | + "hooks": hooks, |
| 214 | + } |
| 215 | + |
| 216 | + return self._build_and_send_tx( |
| 217 | + self.router.functions.initialize( |
| 218 | + { |
| 219 | + "key": pool_key, |
| 220 | + "sqrtPriceX96": sqrt_price_limit_x96, |
| 221 | + } |
| 222 | + ), |
| 223 | + self._get_tx_params(value=qty), |
| 224 | + ) |
| 225 | + |
| 226 | + def modify_position( |
| 227 | + self, |
| 228 | + currency0: ERC20Token, |
| 229 | + currency1: ERC20Token, |
| 230 | + qty: Union[int, Wei], |
| 231 | + fee: int, |
| 232 | + tick_spacing: int, |
| 233 | + tick_upper: int, |
| 234 | + tick_lower: int, |
| 235 | + hooks: AddressLike, |
| 236 | + ) -> HexBytes: |
| 237 | + if currency0 == currency1: |
| 238 | + raise ValueError |
| 239 | + |
| 240 | + pool_key = { |
| 241 | + "currency0": currency0.address, |
| 242 | + "currency1": currency1.address, |
| 243 | + "fee": fee, |
| 244 | + "tickSpacing": tick_spacing, |
| 245 | + "hooks": hooks, |
| 246 | + } |
| 247 | + |
| 248 | + modify_position_params = { |
| 249 | + "tickLower": tick_lower, |
| 250 | + "tickUpper": tick_upper, |
| 251 | + "liquidityDelta": qty, |
| 252 | + } |
| 253 | + |
| 254 | + return self._build_and_send_tx( |
| 255 | + self.router.functions.modifyPosition( |
| 256 | + { |
| 257 | + "key": pool_key, |
| 258 | + "params": modify_position_params, |
| 259 | + } |
| 260 | + ), |
| 261 | + self._get_tx_params(value=qty), |
| 262 | + ) |
| 263 | + |
| 264 | + # ------ Approval Utils ------------------------------------------------------------ |
| 265 | + def approve(self, token: AddressLike, max_approval: Optional[int] = None) -> None: |
| 266 | + """Give an exchange/router max approval of a token.""" |
| 267 | + max_approval = self.max_approval_int if not max_approval else max_approval |
| 268 | + contract_addr = ( |
| 269 | + self._exchange_address_from_token(token) |
| 270 | + if self.version == 1 |
| 271 | + else self.router_address |
| 272 | + ) |
| 273 | + function = _load_contract_erc20(self.w3, token).functions.approve( |
| 274 | + contract_addr, max_approval |
| 275 | + ) |
| 276 | + logger.warning(f"Approving {_addr_to_str(token)}...") |
| 277 | + tx = self._build_and_send_tx(function) |
| 278 | + self.w3.eth.wait_for_transaction_receipt(tx, timeout=6000) |
| 279 | + |
| 280 | + # Add extra sleep to let tx propogate correctly |
| 281 | + time.sleep(1) |
| 282 | + |
| 283 | + # ------ Tx Utils ------------------------------------------------------------------ |
| 284 | + def _deadline(self) -> int: |
| 285 | + """Get a predefined deadline. 10min by default (same as the Uniswap SDK).""" |
| 286 | + return int(time.time()) + 10 * 60 |
| 287 | + |
| 288 | + def _build_and_send_tx( |
| 289 | + self, function: ContractFunction, tx_params: Optional[TxParams] = None |
| 290 | + ) -> HexBytes: |
| 291 | + """Build and send a transaction.""" |
| 292 | + if not tx_params: |
| 293 | + tx_params = self._get_tx_params() |
| 294 | + transaction = function.buildTransaction(tx_params) |
| 295 | + # Uniswap3 uses 20% margin for transactions |
| 296 | + transaction["gas"] = Wei(int(self.w3.eth.estimate_gas(transaction) * 1.2)) |
| 297 | + signed_txn = self.w3.eth.account.sign_transaction( |
| 298 | + transaction, private_key=self.private_key |
| 299 | + ) |
| 300 | + # TODO: This needs to get more complicated if we want to support replacing a transaction |
| 301 | + # FIXME: This does not play nice if transactions are sent from other places using the same wallet. |
| 302 | + try: |
| 303 | + return self.w3.eth.send_raw_transaction(signed_txn.rawTransaction) |
| 304 | + finally: |
| 305 | + logger.debug(f"nonce: {tx_params['nonce']}") |
| 306 | + self.last_nonce = Nonce(tx_params["nonce"] + 1) |
| 307 | + |
| 308 | + def _get_tx_params(self, value: Wei = Wei(0)) -> TxParams: |
| 309 | + """Get generic transaction parameters.""" |
| 310 | + return { |
| 311 | + "from": _addr_to_str(self.address), |
| 312 | + "value": value, |
| 313 | + "nonce": max( |
| 314 | + self.last_nonce, self.w3.eth.get_transaction_count(self.address) |
| 315 | + ), |
| 316 | + } |
| 317 | + |
| 318 | + # ------ Helpers ------------------------------------------------------------ |
| 319 | + |
| 320 | + def get_token(self, address: AddressLike, abi_name: str = "erc20") -> ERC20Token: |
| 321 | + """ |
| 322 | + Retrieves metadata from the ERC20 contract of a given token, like its name, symbol, and decimals. |
| 323 | + """ |
| 324 | + # FIXME: This function should always return the same output for the same input |
| 325 | + # and would therefore benefit from caching |
| 326 | + if address == ETH_ADDRESS: |
| 327 | + return ERC20Token("ETH", ETH_ADDRESS, "Ether", 18) |
| 328 | + token_contract = _load_contract(self.w3, abi_name, address=address) |
| 329 | + try: |
| 330 | + _name = token_contract.functions.name().call() |
| 331 | + _symbol = token_contract.functions.symbol().call() |
| 332 | + decimals = token_contract.functions.decimals().call() |
| 333 | + except Exception as e: |
| 334 | + logger.warning( |
| 335 | + f"Exception occurred while trying to get token {_addr_to_str(address)}: {e}" |
| 336 | + ) |
| 337 | + raise InvalidToken(address) |
| 338 | + try: |
| 339 | + name = _name.decode() |
| 340 | + except: |
| 341 | + name = _name |
| 342 | + try: |
| 343 | + symbol = _symbol.decode() |
| 344 | + except: |
| 345 | + symbol = _symbol |
| 346 | + return ERC20Token(symbol, address, name, decimals) |
| 347 | + |
| 348 | + # ------ Test utilities ------------------------------------------------------------ |
| 349 | + |
| 350 | + def _get_token_addresses(self) -> Dict[str, ChecksumAddress]: |
| 351 | + """ |
| 352 | + Returns a dict with addresses for tokens for the current net. |
| 353 | + Used in testing. |
| 354 | + """ |
| 355 | + netid = int(self.w3.net.version) |
| 356 | + netname = _netid_to_name[netid] |
| 357 | + if netname == "mainnet": |
| 358 | + return tokens |
| 359 | + elif netname == "rinkeby": |
| 360 | + return tokens_rinkeby |
| 361 | + else: |
| 362 | + raise Exception(f"Unknown net '{netname}'") |
0 commit comments