From: 
Subject: Debian changes

The Debian packaging of python-coinbase-advanced-py is maintained in git, using a workflow
similar to the one described in dgit-maint-merge(7).
The Debian delta is represented by this one combined patch; there isn't a
patch queue that can be represented as a quilt series.

A detailed breakdown of the changes is available from their canonical
representation -- git commits in the packaging repository.
For example, to see the changes made by the Debian maintainer in the first
upload of upstream version 1.2.3, you could use:

    % git clone https://git.dgit.debian.org/python-coinbase-advanced-py
    % cd python-coinbase-advanced-py
    % git log --oneline 1.2.3..debian/1.2.3-1 -- . ':!debian'

(If you have dgit, use `dgit clone python-coinbase-advanced-py`, rather than plain `git clone`.)

We don't use debian/source/options single-debian-patch because it has bugs.
Therefore, NMUs etc. may nevertheless have made additional patches.

---

diff --git a/coinbase/websocket/websocket_base.py b/coinbase/websocket/websocket_base.py
index a65b8f4..8578058 100644
--- a/coinbase/websocket/websocket_base.py
+++ b/coinbase/websocket/websocket_base.py
@@ -1,4 +1,5 @@
 import asyncio
+import inspect
 import json
 import logging
 import os
@@ -7,6 +8,7 @@ import threading
 import time
 from multiprocessing import AuthenticationError
 from typing import IO, Callable, List, Optional, Union
+from urllib.parse import urlparse
 
 import backoff
 import websockets
@@ -29,6 +31,24 @@ from coinbase.constants import (
 logger = get_logger("coinbase.WSClient")
 
 
+def _websocket_headers_arg(headers):
+    connect_params = inspect.signature(websockets.connect).parameters
+    if "additional_headers" in connect_params:
+        return {"additional_headers": headers}
+    return {"extra_headers": headers}
+
+
+def _websocket_proxy_arg(url):
+    connect_params = inspect.signature(websockets.connect).parameters
+    if "proxy" not in connect_params:
+        return {}
+
+    hostname = urlparse(url).hostname
+    if hostname in {"localhost", "127.0.0.1", "::1"}:
+        return {"proxy": None}
+    return {}
+
+
 class WSClientException(Exception):
     """
     **WSClientException**
@@ -146,8 +166,13 @@ class WSBase(APIBase):
                 open_timeout=self.timeout,
                 max_size=self.max_size,
                 user_agent_header=USER_AGENT,
-                extra_headers=headers,
-                ssl=ssl.SSLContext() if self.base_url.startswith("wss://") else None,
+                **_websocket_headers_arg(headers),
+                **_websocket_proxy_arg(self.base_url),
+                ssl=(
+                    ssl.create_default_context()
+                    if self.base_url.startswith("wss://")
+                    else None
+                ),
             )
             logger.debug("Successfully connected to %s", self.base_url)
 
@@ -460,7 +485,15 @@ class WSBase(APIBase):
         """
         :meta private:
         """
-        return self.websocket and self.websocket.open
+        if not self.websocket:
+            return False
+
+        open_attr = getattr(self.websocket, "open", None)
+        if open_attr is not None:
+            return open_attr
+
+        state = getattr(self.websocket, "state", None)
+        return getattr(state, "name", None) == "OPEN" or state == 1
 
     async def _resubscribe(self):
         """
diff --git a/tests/websocket/mock_ws_server.py b/tests/websocket/mock_ws_server.py
index beea62d..f53ab16 100644
--- a/tests/websocket/mock_ws_server.py
+++ b/tests/websocket/mock_ws_server.py
@@ -5,7 +5,7 @@ import websockets
 
 
 class MockWebSocketServer:
-    def __init__(self, host="localhost", port=8765):
+    def __init__(self, host="127.0.0.1", port=0):
         self.start_server = None
         self.host = host
         self.port = port
@@ -29,6 +29,7 @@ class MockWebSocketServer:
     async def start(self):
         self.start_server = self.initialize_server()
         self.server = await self.start_server
+        self.port = self.server.sockets[0].getsockname()[1]
 
     async def stop(self):
         WebSocketTask = namedtuple("WebSocketTask", ["ws", "task"])
diff --git a/tests/websocket/test_websocket_base.py b/tests/websocket/test_websocket_base.py
index 2bbecec..b622af2 100644
--- a/tests/websocket/test_websocket_base.py
+++ b/tests/websocket/test_websocket_base.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import json
 import unittest
 from unittest.mock import AsyncMock, patch
@@ -441,7 +442,10 @@ class WSBaseTest(unittest.IsolatedAsyncioTestCase):
         await self.ws_public.close_async()
         self.mock_websocket.close.assert_awaited_once()
 
-    def test_err_calling_private_unauthenticated(self):
+    @patch("websockets.connect", new_callable=AsyncMock)
+    def test_err_calling_private_unauthenticated(self, mock_connect):
+        mock_connect.return_value = self.mock_websocket
+
         # open
         self.ws_public.open()
         self.assertIsNotNone(self.ws_public.websocket)
@@ -453,6 +457,8 @@ class WSBaseTest(unittest.IsolatedAsyncioTestCase):
 
 class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
     # tests that run against a mock websocket server to simulate disconnections
+    MESSAGE_TIMEOUT = 5
+
     async def mock_send(self, message):
         self.messages_queue.put_nowait(message)
 
@@ -466,7 +472,7 @@ class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
         self.ws = WSClient(
             TEST_API_KEY,
             TEST_API_SECRET,
-            base_url="ws://localhost:8765",
+            base_url=f"ws://{self.server.host}:{self.server.port}",
             on_message=on_message,
             retry=False,
         )
@@ -475,7 +481,17 @@ class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
     # self.ws._retry_factor = 1.5
     # self.ws._retry_max = 5
 
+    async def get_message(self):
+        return await asyncio.wait_for(
+            self.messages_queue.get(), timeout=self.MESSAGE_TIMEOUT
+        )
+
     async def asyncTearDown(self):
+        if self.ws._is_websocket_open():
+            await self.ws.close_async()
+        if self.ws._task:
+            with contextlib.suppress(asyncio.TimeoutError):
+                await asyncio.wait_for(self.ws._task, timeout=self.MESSAGE_TIMEOUT)
         await self.server.stop()
 
     async def test_disconnect_error(self):
@@ -495,6 +511,9 @@ class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
         # tests that client can automatically reconnect after a WSClientConnectionClosedException
 
         self.ws.retry = True
+        self.ws._retry_base = 1
+        self.ws._retry_factor = 0.05
+        self.ws._retry_max_tries = 50
 
         # Open WebSocket connection
         await self.ws.open_async()
@@ -502,15 +521,15 @@ class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
             product_ids=["BTC-USD", "ETH-USD"], channels=["ticker"]
         )
 
-        await self.messages_queue.get()
+        await self.get_message()
         await self.ws.subscribe_async(product_ids=["BTC-USD"], channels=["heartbeats"])
-        await self.messages_queue.get()
+        await self.get_message()
 
         # disconnect and restart the server
         await self.server.restart_with_error()
 
         # assert resubscribe messages
-        resubscribe_1 = await self.messages_queue.get()
+        resubscribe_1 = await self.get_message()
         resubscribe_1_json = json.loads(resubscribe_1)
         self.assertEqual(resubscribe_1_json["type"], SUBSCRIBE_MESSAGE_TYPE)
         self.assertEqual(
@@ -518,7 +537,7 @@ class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
         )
         self.assertEqual(resubscribe_1_json["channel"], "ticker")
 
-        resubscribe_2 = await self.messages_queue.get()
+        resubscribe_2 = await self.get_message()
         resubscribe_2_json = json.loads(resubscribe_2)
         self.assertEqual(resubscribe_2_json["type"], SUBSCRIBE_MESSAGE_TYPE)
         self.assertEqual(resubscribe_2_json["product_ids"], ["BTC-USD"])
@@ -535,9 +554,9 @@ class WSDisconnectionTests(unittest.IsolatedAsyncioTestCase):
             product_ids=["BTC-USD", "ETH-USD"], channels=["ticker"]
         )
 
-        await self.messages_queue.get()
+        await self.get_message()
         await self.ws.subscribe_async(product_ids=["BTC-USD"], channels=["heartbeats"])
-        await self.messages_queue.get()
+        await self.get_message()
 
         with self.assertRaises(WSClientConnectionClosedException):
             # disconnect and restart the server
