source: trunk/DirectShowSpy/Common.h @ 267

Last change on this file since 267 was 267, checked in by roman, 10 years ago

Updated source code/copyright notice

  • Property svn:keywords set to Id
File size: 15.9 KB
Line 
1////////////////////////////////////////////////////////////
2// Copyright (C) Roman Ryltsov, 2008-2014
3// Created by Roman Ryltsov roman@alax.info
4//
5// This source code is published to complement DirectShowSpy developer powertoy
6// and demonstrate the internal use of APIs and tricks powering the tool. It is
7// allowed to freely re-use the portions of the code in other projects, commercial or otherwise
8// (provided that you don’t pretend that you wrote the original tool).
9//
10// Please keep in mind that DirectShowSpy is a developer tool, it is strongly recommended
11// that it is not shipped with release grade software. The advise applies to hooking methods
12// used by DirectShowSpy in general as well.
13
14#pragma once
15
16////////////////////////////////////////////////////////////
17// CProcessTokenPrivileges
18
19class CProcessTokenPrivileges
20{
21public:
22        CAccessToken m_ProcessToken;
23        CTokenPrivileges m_OriginalProcessTokenPrivileges;
24        BOOL m_bTakeOwnershipPrivilegeEnabled;
25        BOOL m_bRestorePrivilegeEnabled;
26
27public:
28// CProcessTokenPrivileges
29        CProcessTokenPrivileges() throw() :
30                m_bTakeOwnershipPrivilegeEnabled(FALSE),
31                m_bRestorePrivilegeEnabled(FALSE)
32        {
33        }
34        ~CProcessTokenPrivileges()
35        {
36                if(m_bTakeOwnershipPrivilegeEnabled)
37                {
38                        __E(m_ProcessToken.EnableDisablePrivileges(m_OriginalProcessTokenPrivileges));
39                        m_bRestorePrivilegeEnabled = FALSE;
40                        m_bTakeOwnershipPrivilegeEnabled = FALSE;
41                }
42        }
43        VOID Adjust()
44        {
45                _A(!m_ProcessToken.GetHandle());
46                // NOTE: Enable SE_TAKE_OWNERSHIP_NAME and SE_RESTORE_NAME privileges in order to be able to access registry key permissions
47                __E(m_ProcessToken.GetProcessToken(TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES));
48                __E(m_ProcessToken.EnablePrivilege(SE_TAKE_OWNERSHIP_NAME, &m_OriginalProcessTokenPrivileges));
49                m_bTakeOwnershipPrivilegeEnabled = TRUE;
50                __E(m_ProcessToken.EnablePrivilege(SE_RESTORE_NAME));
51                m_bRestorePrivilegeEnabled = TRUE;
52        }
53};
54
55////////////////////////////////////////////////////////////
56// CClassIdentifierRegKeySecurity
57
58class CClassIdentifierRegKeySecurity
59{
60public:
61        CLSID m_ClassIdentifier;
62        CRegKey m_Key;
63        CSecurityDesc m_OriginalSecurityDescriptor;
64
65public:
66// CClassIdentifierRegKeySecurity
67        static VOID GetRegKeySecurityDescriptor(CRegKey& Key, SECURITY_INFORMATION SecurityInformation, CSecurityDesc& SecurityDescriptor)
68        {
69                DWORD nSecurityDescriptorSize = 0;
70                Key.GetKeySecurity(SecurityInformation, NULL, &nSecurityDescriptorSize);
71                CTempBuffer<SECURITY_DESCRIPTOR> pSecurityDescriptor;
72                _W(pSecurityDescriptor.AllocateBytes(nSecurityDescriptorSize));
73                __C(HRESULT_FROM_WIN32(Key.GetKeySecurity(SecurityInformation, pSecurityDescriptor, &nSecurityDescriptorSize)));
74                SecurityDescriptor = *pSecurityDescriptor;
75        }
76        static CString StringFromSecurityDescriptor(CSecurityDesc& SecurityDescriptor)
77        {
78                CString sSecurityDescriptorString;
79                __E(SecurityDescriptor.ToString(&sSecurityDescriptorString));
80                return sSecurityDescriptorString;
81        }
82        static CString StringFromRegKeySecurityDescriptor(CRegKey& Key, SECURITY_INFORMATION SecurityInformation)
83        {
84                CSecurityDesc SecurityDescriptor;
85                GetRegKeySecurityDescriptor(Key, SecurityInformation, SecurityDescriptor);
86                return StringFromSecurityDescriptor(SecurityDescriptor);
87        }
88        CClassIdentifierRegKeySecurity(const CLSID& ClassIdentifier) throw() :
89                m_ClassIdentifier(ClassIdentifier)
90        {
91        }
92        ~CClassIdentifierRegKeySecurity()
93        {
94                if(m_Key)
95                {
96                        // QUES: Would it be more reliable to leave ownership and revert DACL only in absence of Restore privilege?
97                        __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) m_OriginalSecurityDescriptor))));
98                        _Z4(atlTraceGeneral, 4, _T("Owner & DACL: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION));
99                        __C(HRESULT_FROM_WIN32(m_Key.Close()));
100                }
101        }
102        BOOL Adjust()
103        {
104                if(m_Key)
105                        return FALSE;
106                const CString sKeyName = AtlFormatString(_T("CLSID\\%ls"), _PersistHelper::StringFromIdentifier(m_ClassIdentifier));
107                __C(HRESULT_FROM_WIN32(m_Key.Open(HKEY_CLASSES_ROOT, sKeyName, READ_CONTROL | WRITE_OWNER)));
108                GetRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION, m_OriginalSecurityDescriptor);
109                // NOTE:
110                //   Windows 5.1: CLSID {E436EBB2-524F-11CE-9F53-0020AF0BA770} Key Owner: O:AB (Administrators)
111                //   Windows 6.0: CLSID {E436EBB2-524F-11CE-9F53-0020AF0BA770} Key Owner: O:S-1-5-80-956008885-3418522649-1831038044-1853292631-2271478464 (TrustedInstaller)
112                _Z4(atlTraceGeneral, 4, _T("ClassIdentifier %ls, m_OriginalSecurityDescriptor %s\n"), _PersistHelper::StringFromIdentifier(m_ClassIdentifier), StringFromSecurityDescriptor(m_OriginalSecurityDescriptor));
113                // NOTE: Take ownership of the key to Administrators in order to be able to update key DACL
114                // QUES: Would it be more reliable to take ownership to self in absence of Restore privilege?
115                CSecurityDesc AdministratorsOwnerSecurityDescriptor;
116                AdministratorsOwnerSecurityDescriptor.SetOwner(Sids::Admins());
117                __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(OWNER_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) AdministratorsOwnerSecurityDescriptor))));
118                _Z4(atlTraceGeneral, 4, _T("Owner: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION));
119                // NOTE: Reopen the key to obtain given privileges
120                __C(HRESULT_FROM_WIN32(m_Key.Close()));
121                __C(HRESULT_FROM_WIN32(m_Key.Open(HKEY_CLASSES_ROOT, sKeyName, READ_CONTROL | WRITE_DAC | WRITE_OWNER)));
122                // NOTE: Adjust key DACL in order to make the key writable
123                CSecurityDesc AccessListSecurityDescriptor = m_OriginalSecurityDescriptor;
124                CDacl AccessList;
125                __E(AccessListSecurityDescriptor.GetDacl(&AccessList));
126                _W(AccessList.AddAllowedAce(Sids::Admins(), GENERIC_ALL, CONTAINER_INHERIT_ACE));
127                AccessListSecurityDescriptor.SetDacl(AccessList);
128                __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(DACL_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) AccessListSecurityDescriptor))));
129                _Z4(atlTraceGeneral, 4, _T("DACL: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, DACL_SECURITY_INFORMATION));
130                return TRUE;
131        }
132};
133
134////////////////////////////////////////////////////////////
135// TreatAsUpdateRegistryFromResource
136
137template <typename T>
138inline VOID TreatAsUpdateRegistryFromResource(const CLSID& TreatAsClassIdentifier, BOOL bRegister)
139{
140        _Z2(atlTraceRegistrar, 2, _T("TreatAsClassIdentifier %ls, bRegister %d\n"), _PersistHelper::StringFromIdentifier(TreatAsClassIdentifier), bRegister);
141        // NOTE: Registration is much more sophisticated starting Vista operating system
142        const ULONG nOsVersion = GetOsVersion();
143        _Z4(atlTraceRegistrar, 4, _T("nOsVersion 0x%08x\n"), nOsVersion);
144        CProcessTokenPrivileges ProcessTokenPrivileges;
145        if(nOsVersion >= 0x060000) // Win Vista+
146                _ATLTRY
147                {
148                        ProcessTokenPrivileges.Adjust();
149                }
150                _ATLCATCHALL()
151                {
152                        _Z_EXCEPTION();
153                }
154        CLSID CurrentTreatAsClassIdentifier = CLSID_NULL;
155        const HRESULT nCoGetTreatAsClassResult = CoGetTreatAsClass(TreatAsClassIdentifier, &CurrentTreatAsClassIdentifier);
156        __C(nCoGetTreatAsClassResult);
157        _Z4(atlTraceRegistrar, 4, _T("bRegister %d, nCoGetTreatAsClassResult 0x%08x, CurrentTreatAsClassIdentifier %ls\n"), bRegister, nCoGetTreatAsClassResult, _PersistHelper::StringFromIdentifier(CurrentTreatAsClassIdentifier));
158        __D(!bRegister || nCoGetTreatAsClassResult != S_OK || CurrentTreatAsClassIdentifier == T::GetObjectCLSID(), E_UNNAMED);
159        CClassIdentifierRegKeySecurity ClassIdentifierRegKeySecurity(TreatAsClassIdentifier);
160        if(!bRegister && nCoGetTreatAsClassResult == S_OK)
161        {
162                if(nOsVersion >= 0x060000) // Win Vista+
163                        ClassIdentifierRegKeySecurity.Adjust();
164                __C(CoTreatAsClass(TreatAsClassIdentifier, CLSID_NULL));
165        }
166        _A(_pAtlModule);
167        UpdateRegistryFromResource<T>(bRegister);
168        if(bRegister)
169        {
170                if(nOsVersion >= 0x060000) // Win Vista+
171                        ClassIdentifierRegKeySecurity.Adjust();
172                #if _DEVELOPMENT
173                        const HRESULT nCoTreatAsClassResult = CoTreatAsClass(TreatAsClassIdentifier, T::GetObjectCLSID());
174                        _Z2(atlTraceRegistrar, SUCCEEDED(nCoTreatAsClassResult) ? 4 : 2, _T("nCoTreatAsClassResult 0x%08x\n"), nCoTreatAsClassResult);
175                        __C(nCoTreatAsClassResult);
176                        const HRESULT nCoGetTreatAsClassResult = CoGetTreatAsClass(TreatAsClassIdentifier, &CurrentTreatAsClassIdentifier);
177                        _Z4(atlTraceRegistrar, 4, _T("nCoGetTreatAsClassResult 0x%08x, CurrentTreatAsClassIdentifier %ls\n"), nCoGetTreatAsClassResult, _PersistHelper::StringFromIdentifier(CurrentTreatAsClassIdentifier));
178                        _A(CurrentTreatAsClassIdentifier == T::GetObjectCLSID());
179                #else
180                        __C(CoTreatAsClass(TreatAsClassIdentifier, T::GetObjectCLSID()));
181                #endif // _DEVELOPMENT
182        }
183}
184
185////////////////////////////////////////////////////////////
186// CBlackListAwareComCreatorT
187
188template <typename T, const CLSID* t_pClassIdentifier>
189class ATL_NO_VTABLE CTransparentCoClassT
190{
191public:
192// CTransparentCoClassT
193        static HINSTANCE CoLoadOriginalLibrary()
194        {
195                const HINSTANCE hModule = CoLoadLibrary(const_cast<LPOLESTR>((LPCOLESTR) CT2COLE(T::GetOriginalLibraryName())), TRUE);
196                __E(hModule);
197                return hModule;
198        }
199        static CComPtr<IUnknown> CoCreateOriginalInstance(HINSTANCE hModule, IUnknown* pControllingUnknown = NULL)
200        {
201                typedef HRESULT (WINAPI *DLLGETCLASSOBJECT)(REFCLSID, REFIID, VOID**);
202                DLLGETCLASSOBJECT DllGetClassObject = (DLLGETCLASSOBJECT) GetProcAddress(hModule, "DllGetClassObject");
203                __E(DllGetClassObject);
204                CComPtr<IClassFactory> pClassFactory;
205                __C(DllGetClassObject(*t_pClassIdentifier, __uuidof(IClassFactory), (VOID**) &pClassFactory));
206                _A(pClassFactory);
207                CComPtr<IUnknown> pUnknown;
208                __C(pClassFactory->CreateInstance(pControllingUnknown, __uuidof(IUnknown), (VOID**) &pUnknown));
209                return pUnknown;
210        }
211        static CComPtr<IUnknown> CoCreateOriginalInstance(IUnknown* pControllingUnknown = NULL)
212        {
213                CComPtr<IUnknown> pUnknown;
214                const HINSTANCE hModule = CoLoadOriginalLibrary();
215                _ATLTRY
216                {
217                        pUnknown = CoCreateOriginalInstance(hModule, pControllingUnknown);
218                }
219                _ATLCATCHALL()
220                {
221                        CoFreeLibrary(hModule);
222                        _ATLRETHROW;
223                }
224                CoFreeLibrary(hModule);
225                return pUnknown;
226        }
227};
228
229////////////////////////////////////////////////////////////
230// CBlackListAwareComCreatorT
231
232template <typename _ObjectClass, typename _Class, LPCTSTR* t_ppszName>
233class CBlackListAwareComCreatorT :
234        public CComCreator<_ObjectClass>
235{
236public:
237// CBlackListAwareComCreatorT
238        static HRESULT WINAPI CreateInstance(VOID* pvControllingUnknown, REFIID InterfaceIdentifier, VOID** ppvObject) throw()
239        {
240                _A(ppvObject);
241                *ppvObject = NULL;
242                #pragma region Check Black List
243                static INT g_nBlackListed = 0; // 0 Unknown, 1 No, 2 Yes
244                {
245                        _A(_pAtlModule);
246                        CComCritSecLock<CComCriticalSection> Lock(_pAtlModule->m_csStaticDataInitAndTypeInfo);
247                        if(!g_nBlackListed)
248                        {
249                                TCHAR pszPath[MAX_PATH] = { 0 };
250                                _W(GetModuleFileName(NULL, pszPath, DIM(pszPath)));
251                                LPTSTR pszFileName = FindFileName(pszPath);
252                                _A(pszFileName);
253                                RemoveExtension(pszFileName);
254                                const CString sBlackList = _RegKeyHelper::QueryStringValue(HKEY_LOCAL_MACHINE, REGISTRY_ROOT, AtlFormatString(_T("%s Black List"), *t_ppszName));
255                                CRoArrayT<CString> BlackListArray;
256                                _StringHelper::GetCommaSeparatedItems(sBlackList, BlackListArray);
257                                BOOL bFound = FALSE;
258                                for(SIZE_T nIndex = 0; nIndex < BlackListArray.GetCount(); nIndex++)
259                                {
260                                        CPath sFileName = (LPCTSTR) BlackListArray[nIndex];
261                                        sFileName.RemoveExtension();
262                                        if(_tcsicmp(sFileName, pszFileName) == 0)
263                                        {
264                                                _Z2(atlTraceCOM, 2, _T("Will instantiate original COM class, sFileName \"%s\"\n"), sFileName);
265                                                bFound = TRUE;
266                                                break;
267                                        }
268                                }
269                                g_nBlackListed = bFound ? 2 : 1;
270                        }
271                }
272                #pragma endregion
273                if(g_nBlackListed == 2)
274                        #pragma region CoCreateInstance Original Class
275                        _ATLTRY
276                        {
277                                _A(_pAtlModule);
278                                const LONG nLockResult = _pAtlModule->Lock();
279                                _ATLTRY
280                                {
281                                        CComPtr<IUnknown> pUnknown = _Class::CoCreateOriginalInstance((IUnknown*) pvControllingUnknown);
282                                        if(InterfaceIdentifier == __uuidof(IUnknown))
283                                                *ppvObject = pUnknown.Detach();
284                                        else
285                                                __C(pUnknown->QueryInterface(InterfaceIdentifier, ppvObject));
286                                }
287                                _ATLCATCHALL()
288                                {
289                                        _pAtlModule->Unlock();
290                                        _ATLRETHROW;
291                                }
292                                const LONG nUnlockResult = _pAtlModule->Unlock();
293                                _Z6(atlTraceGeneral, 6, _T("nLockResult %d, nUnlockResult %d\n"), nLockResult, nUnlockResult);
294                                return S_OK;
295                        }
296                        _ATLCATCH(Exception)
297                        {
298                                _A(FAILED(Exception));
299                                _C(Exception);
300                        }
301                        #pragma endregion
302                return __super::CreateInstance(pvControllingUnknown, InterfaceIdentifier, ppvObject);
303        }
304};
305
306////////////////////////////////////////////////////////////
307// CHookHostT
308
309template <typename T, typename IHook, LPCTSTR* t_ppszHookName>
310class CHookHostT
311{
312public:
313
314        ////////////////////////////////////////////////////////
315        // CHookArray
316
317        class CHookArray :
318                public CRoArrayT<CComPtr<IHook> >
319        {
320        public:
321        // CHookArray
322        };
323
324private:
325        mutable CRoCriticalSection m_HookCriticalSection;
326        BOOL m_bHookArrayInitialized;
327        CHookArray m_HookArray;
328
329        VOID InitializeHookArray()
330        {
331                _A(m_HookArray.IsEmpty());
332                CRoListT<CLSID> ClassIdentifierList;
333                static const HKEY g_phParentKeys[] = { HKEY_LOCAL_MACHINE, HKEY_CURRENT_USER };
334                static const LPCTSTR g_ppszKeyNameFormats[] = { _T("SOFTWARE\\Classes\\"), _T("Software\\Classes\\") };
335                for(SIZE_T nKeyIndex = 0; nKeyIndex < DIM(g_phParentKeys); nKeyIndex++)
336                {
337                        const CString sKeyName = AtlFormatString(_T("%sCLSID\\%ls\\Hooks\\%s"), g_ppszKeyNameFormats[nKeyIndex], _PersistHelper::StringFromIdentifier(T::GetObjectCLSID()), *t_ppszHookName);
338                        CRegKey Key;
339                        if(FAILED(HRESULT_FROM_WIN32(Key.Open(g_phParentKeys[nKeyIndex], sKeyName, KEY_READ))))
340                                continue;
341                        for(DWORD nIndex = 0; ; nIndex++)
342                        {
343                                DWORD nNameLength = 0;
344                                RegEnumKeyEx(Key, nIndex, NULL, &nNameLength, NULL, NULL, NULL, NULL);
345                                nNameLength = max(2 * nNameLength, 256);
346                                CTempBuffer<TCHAR, 4096> pszName(nNameLength);
347                                const HRESULT nRegEnumKeyResult = HRESULT_FROM_WIN32(RegEnumKeyEx(Key, nIndex, pszName, &nNameLength, NULL, NULL, NULL, NULL));
348                                if(FAILED(nRegEnumKeyResult))
349                                {
350                                        __D(nRegEnumKeyResult == HRESULT_FROM_WIN32(ERROR_NO_MORE_ITEMS), nRegEnumKeyResult);
351                                        break;
352                                }
353                                _ATLTRY
354                                {
355                                        const CLSID ClassIdentifier = _PersistHelper::ClassIdentifierFromString(CT2CW(pszName));
356                                        _Z4(atlTraceGeneral, 4, _T("ClassIdentifier %ls\n"), _PersistHelper::StringFromIdentifier(ClassIdentifier));
357                                        __D(ClassIdentifier != CLSID_NULL, E_UNNAMED);
358                                        if(ClassIdentifierList.Find(ClassIdentifier))
359                                                continue;
360                                        _W(ClassIdentifierList.AddTail(ClassIdentifier));
361                                        CComPtr<IHook> pHook;
362                                        __C(pHook.CoCreateInstance(ClassIdentifier));
363                                        _W(m_HookArray.Add(pHook) >= 0);
364                                }
365                                _ATLCATCHALL()
366                                {
367                                        _Z_EXCEPTION();
368                                }
369                        }
370                }
371        }
372
373public:
374// CHookHostT
375        CHookHostT() throw() :
376                m_bHookArrayInitialized(FALSE)
377        {
378        }
379        SIZE_T GetHookArray(CHookArray& HookArray)
380        {
381                _A(HookArray.IsEmpty());
382                CRoCriticalSectionLock HookLock(m_HookCriticalSection);
383                if(!m_bHookArrayInitialized)
384                        _ATLTRY
385                        {
386                                m_bHookArrayInitialized = TRUE;
387                                InitializeHookArray();
388                        }
389                        _ATLCATCHALL()
390                        {
391                                _Z_EXCEPTION();
392                        }
393                HookArray.Append(m_HookArray);
394                return HookArray.GetCount();
395        }
396};
397
398#define HOOK_PROLOG(Base) \
399        _ATLTRY \
400        { \
401                Base::CHookArray HookArray; \
402                if(Base::GetHookArray(HookArray)) \
403                { \
404                        T* pT = static_cast<T*>(this); \
405                        for(SIZE_T nIndex = 0; nIndex < HookArray.GetCount(); nIndex++) \
406                        { \
407                                BOOL bDefault = TRUE; \
408                                const HRESULT nResult = HookArray[nIndex]->
409                                               
410#define HOOK_EPILOG() \
411                                if(!bDefault) \
412                                        return nResult; \
413                                _A(SUCCEEDED(nResult)); \
414                        } \
415                } \
416        } \
417        _ATLCATCHALL() \
418        { \
419                _Z_EXCEPTION(); \
420        }
421
Note: See TracBrowser for help on using the repository browser.