OAuth refresh

This commit is contained in:
Zotify 2024-08-15 16:16:50 +12:00
parent 446c8c2a52
commit faca12783e
2 changed files with 59 additions and 19 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from enum import IntEnum
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
@ -69,7 +70,7 @@ class Session(LibrespotSession):
def __init__( def __init__(
self, self,
session_builder: LibrespotSession.Builder, session_builder: LibrespotSession.Builder,
token: TokenProvider.StoredToken, oauth: OAuth,
language: str = "en", language: str = "en",
) -> None: ) -> None:
""" """
@ -89,7 +90,7 @@ class Session(LibrespotSession):
), ),
ApResolver.get_random_accesspoint(), ApResolver.get_random_accesspoint(),
) )
self.__token = token self.__oauth = oauth
self.__language = language self.__language = language
self.connect() self.connect()
self.authenticate(session_builder.login_credentials) self.authenticate(session_builder.login_credentials)
@ -112,8 +113,7 @@ class Session(LibrespotSession):
.build() .build()
) )
session = LibrespotSession.Builder(conf).stored_file(str(cred_file)) session = LibrespotSession.Builder(conf).stored_file(str(cred_file))
token = session.login_credentials.auth_data # TODO: this is wrong return Session(session, OAuth(), language) # TODO
return Session(session, token, language)
@staticmethod @staticmethod
def from_oauth( def from_oauth(
@ -148,7 +148,7 @@ class Session(LibrespotSession):
typ=Authentication.AuthenticationType.values()[3], typ=Authentication.AuthenticationType.values()[3],
auth_data=token.access_token.encode(), auth_data=token.access_token.encode(),
) )
return Session(session, token, language) return Session(session, auth, language)
def __get_playable( def __get_playable(
self, playable_id: PlayableId, quality: Quality self, playable_id: PlayableId, quality: Quality
@ -188,9 +188,9 @@ class Session(LibrespotSession):
self.api(), self.api(),
) )
def token(self) -> TokenProvider.StoredToken: def oauth(self) -> OAuth:
"""Returns API token""" """Returns OAuth service"""
return self.__token return self.__oauth
def language(self) -> str: def language(self) -> str:
"""Returns session language""" """Returns session language"""
@ -288,7 +288,7 @@ class TokenProvider(LibrespotTokenProvider):
self._session = session self._session = session
def get_token(self, *scopes) -> TokenProvider.StoredToken: def get_token(self, *scopes) -> TokenProvider.StoredToken:
return self._session.token() return self._session.oauth().get_token()
class StoredToken(LibrespotTokenProvider.StoredToken): class StoredToken(LibrespotTokenProvider.StoredToken):
def __init__(self, obj): def __init__(self, obj):
@ -309,6 +309,11 @@ class OAuth:
self.__server_thread.start() self.__server_thread.start()
def get_authorization_url(self) -> str: def get_authorization_url(self) -> str:
"""
Generate OAuth URL
Returns:
OAuth URL
"""
self.__code_verifier = generate_code_verifier() self.__code_verifier = generate_code_verifier()
code_challenge = get_code_challenge(self.__code_verifier) code_challenge = get_code_challenge(self.__code_verifier)
params = { params = {
@ -322,19 +327,48 @@ class OAuth:
return f"{AUTH_URL}authorize?{urlencode(params)}" return f"{AUTH_URL}authorize?{urlencode(params)}"
def await_token(self) -> TokenProvider.StoredToken: def await_token(self) -> TokenProvider.StoredToken:
"""
Blocks until server thread gets token
Returns:
StoredToken
"""
self.__server_thread.join() self.__server_thread.join()
return self.__token return self.__token
def set_token(self, code: str) -> None: def get_token(self) -> TokenProvider.StoredToken:
"""
Gets a valid token
Returns:
StoredToken
"""
if self.__token is None:
raise RuntimeError("Session isn't authenticated!")
elif self.__token.expired():
self.set_token(self.__token.refresh_token, OAuth.RequestType.REFRESH)
return self.__token
def set_token(self, code: str, request_type: RequestType) -> None:
"""
Fetches and sets stored token
Returns:
StoredToken
"""
token_url = f"{AUTH_URL}api/token" token_url = f"{AUTH_URL}api/token"
headers = {"Content-Type": "application/x-www-form-urlencoded"} headers = {"Content-Type": "application/x-www-form-urlencoded"}
body = { if request_type == OAuth.RequestType.LOGIN:
"grant_type": "authorization_code", body = {
"code": code, "grant_type": "authorization_code",
"redirect_uri": REDIRECT_URI, "code": code,
"client_id": CLIENT_ID, "redirect_uri": REDIRECT_URI,
"code_verifier": self.__code_verifier, "client_id": CLIENT_ID,
} "code_verifier": self.__code_verifier,
}
elif request_type == OAuth.RequestType.REFRESH:
body = {
"grant_type": "refresh_token",
"refresh_token": code,
"client_id": CLIENT_ID,
}
response = post(token_url, headers=headers, data=body) response = post(token_url, headers=headers, data=body)
if response.status_code != 200: if response.status_code != 200:
raise IOError( raise IOError(
@ -348,6 +382,10 @@ class OAuth:
httpd.authenticator = self httpd.authenticator = self
httpd.serve_forever() httpd.serve_forever()
class RequestType(IntEnum):
LOGIN = 0
REFRESH = 1
class OAuthHTTPServer(HTTPServer): class OAuthHTTPServer(HTTPServer):
authenticator: OAuth authenticator: OAuth
@ -371,7 +409,9 @@ class OAuth:
if code: if code:
if isinstance(self.server, OAuth.OAuthHTTPServer): if isinstance(self.server, OAuth.OAuthHTTPServer):
self.server.authenticator.set_token(code[0]) self.server.authenticator.set_token(
code[0], OAuth.RequestType.LOGIN
)
self.send_response(200) self.send_response(200)
self.send_header("Content-type", "text/html") self.send_header("Content-type", "text/html")
self.end_headers() self.end_headers()

View file

@ -106,7 +106,7 @@ class Selection:
def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None: def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None:
arg_range = range(len(args)) arg_range = range(len(args))
category_str = " # " + " ".join("{:<38}" for _ in arg_range) category_str = "# " + " ".join("{:<38}" for _ in arg_range)
print(category_str.format(*[s.upper() for s in list(args)])) print(category_str.format(*[s.upper() for s in list(args)]))
for item in items: for item in items:
count += 1 count += 1