source: trunk/DirectShowSpy/Common.h @ 196

Last change on this file since 196 was 196, checked in by roman, 9 years ago

Cosmetic fixes, new .BAT names, UnregisterTreatAsClasses? export to force removal of TreatAs? keys

  • Property svn:keywords set to Id
File size: 15.3 KB
Line 
1////////////////////////////////////////////////////////////
2// Copyright (C) Roman Ryltsov, 2008-2011
3// Created by Roman Ryltsov roman@alax.info
4
5#pragma once
6
7////////////////////////////////////////////////////////////
8// CProcessTokenPrivileges
9
10class CProcessTokenPrivileges
11{
12public:
13        CAccessToken m_ProcessToken;
14        CTokenPrivileges m_OriginalProcessTokenPrivileges;
15        BOOL m_bTakeOwnershipPrivilegeEnabled;
16        BOOL m_bRestorePrivilegeEnabled;
17
18public:
19// CProcessTokenPrivileges
20        CProcessTokenPrivileges() throw() :
21                m_bTakeOwnershipPrivilegeEnabled(FALSE),
22                m_bRestorePrivilegeEnabled(FALSE)
23        {
24        }
25        ~CProcessTokenPrivileges()
26        {
27                if(m_bTakeOwnershipPrivilegeEnabled)
28                {
29                        __E(m_ProcessToken.EnableDisablePrivileges(m_OriginalProcessTokenPrivileges));
30                        m_bRestorePrivilegeEnabled = FALSE;
31                        m_bTakeOwnershipPrivilegeEnabled = FALSE;
32                }
33        }
34        VOID Adjust()
35        {
36                _A(!m_ProcessToken.GetHandle());
37                // NOTE: Enable SE_TAKE_OWNERSHIP_NAME and SE_RESTORE_NAME privileges in order to be able to access registry key permissions
38                __E(m_ProcessToken.GetProcessToken(TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES));
39                __E(m_ProcessToken.EnablePrivilege(SE_TAKE_OWNERSHIP_NAME, &m_OriginalProcessTokenPrivileges));
40                m_bTakeOwnershipPrivilegeEnabled = TRUE;
41                __E(m_ProcessToken.EnablePrivilege(SE_RESTORE_NAME));
42                m_bRestorePrivilegeEnabled = TRUE;
43        }
44};
45
46////////////////////////////////////////////////////////////
47// CClassIdentifierRegKeySecurity
48
49class CClassIdentifierRegKeySecurity
50{
51public:
52        CLSID m_ClassIdentifier;
53        CRegKey m_Key;
54        CSecurityDesc m_OriginalSecurityDescriptor;
55
56public:
57// CClassIdentifierRegKeySecurity
58        static VOID GetRegKeySecurityDescriptor(CRegKey& Key, SECURITY_INFORMATION SecurityInformation, CSecurityDesc& SecurityDescriptor)
59        {
60                DWORD nSecurityDescriptorSize = 0;
61                Key.GetKeySecurity(SecurityInformation, NULL, &nSecurityDescriptorSize);
62                CTempBuffer<SECURITY_DESCRIPTOR> pSecurityDescriptor;
63                _W(pSecurityDescriptor.AllocateBytes(nSecurityDescriptorSize));
64                __C(HRESULT_FROM_WIN32(Key.GetKeySecurity(SecurityInformation, pSecurityDescriptor, &nSecurityDescriptorSize)));
65                SecurityDescriptor = *pSecurityDescriptor;
66        }
67        static CString StringFromSecurityDescriptor(CSecurityDesc& SecurityDescriptor)
68        {
69                CString sSecurityDescriptorString;
70                __E(SecurityDescriptor.ToString(&sSecurityDescriptorString));
71                return sSecurityDescriptorString;
72        }
73        static CString StringFromRegKeySecurityDescriptor(CRegKey& Key, SECURITY_INFORMATION SecurityInformation)
74        {
75                CSecurityDesc SecurityDescriptor;
76                GetRegKeySecurityDescriptor(Key, SecurityInformation, SecurityDescriptor);
77                return StringFromSecurityDescriptor(SecurityDescriptor);
78        }
79        CClassIdentifierRegKeySecurity(const CLSID& ClassIdentifier) throw() :
80                m_ClassIdentifier(ClassIdentifier)
81        {
82        }
83        ~CClassIdentifierRegKeySecurity()
84        {
85                if(m_Key)
86                {
87                        // QUES: Would it be more reliable to leave ownership and revert DACL only in absence of Restore privilege?
88                        __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) m_OriginalSecurityDescriptor))));
89                        _Z4(atlTraceGeneral, 4, _T("Owner & DACL: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION));
90                        __C(HRESULT_FROM_WIN32(m_Key.Close()));
91                }
92        }
93        BOOL Adjust()
94        {
95                if(m_Key)
96                        return FALSE;
97                const CString sKeyName = AtlFormatString(_T("CLSID\\%ls"), _PersistHelper::StringFromIdentifier(m_ClassIdentifier));
98                __C(HRESULT_FROM_WIN32(m_Key.Open(HKEY_CLASSES_ROOT, sKeyName, READ_CONTROL | WRITE_OWNER)));
99                GetRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION, m_OriginalSecurityDescriptor);
100                // NOTE:
101                //   Windows 5.1: CLSID {E436EBB2-524F-11CE-9F53-0020AF0BA770} Key Owner: O:AB (Administrators)
102                //   Windows 6.0: CLSID {E436EBB2-524F-11CE-9F53-0020AF0BA770} Key Owner: O:S-1-5-80-956008885-3418522649-1831038044-1853292631-2271478464 (TrustedInstaller)
103                _Z4(atlTraceGeneral, 4, _T("ClassIdentifier %ls, m_OriginalSecurityDescriptor %s\n"), _PersistHelper::StringFromIdentifier(m_ClassIdentifier), StringFromSecurityDescriptor(m_OriginalSecurityDescriptor));
104                // NOTE: Take ownership of the key to Administrators in order to be able to update key DACL
105                // QUES: Would it be more reliable to take ownership to self in absence of Restore privilege?
106                CSecurityDesc AdministratorsOwnerSecurityDescriptor;
107                AdministratorsOwnerSecurityDescriptor.SetOwner(Sids::Admins());
108                __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(OWNER_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) AdministratorsOwnerSecurityDescriptor))));
109                _Z4(atlTraceGeneral, 4, _T("Owner: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION));
110                // NOTE: Reopen the key to obtain given privileges
111                __C(HRESULT_FROM_WIN32(m_Key.Close()));
112                __C(HRESULT_FROM_WIN32(m_Key.Open(HKEY_CLASSES_ROOT, sKeyName, READ_CONTROL | WRITE_DAC | WRITE_OWNER)));
113                // NOTE: Adjust key DACL in order to make the key writable
114                CSecurityDesc AccessListSecurityDescriptor = m_OriginalSecurityDescriptor;
115                CDacl AccessList;
116                __E(AccessListSecurityDescriptor.GetDacl(&AccessList));
117                _W(AccessList.AddAllowedAce(Sids::Admins(), GENERIC_ALL, CONTAINER_INHERIT_ACE));
118                AccessListSecurityDescriptor.SetDacl(AccessList);
119                __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(DACL_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) AccessListSecurityDescriptor))));
120                _Z4(atlTraceGeneral, 4, _T("DACL: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, DACL_SECURITY_INFORMATION));
121                return TRUE;
122        }
123};
124
125////////////////////////////////////////////////////////////
126// TreatAsUpdateRegistryFromResource
127
128template <typename T>
129inline VOID TreatAsUpdateRegistryFromResource(const CLSID& TreatAsClassIdentifier, BOOL bRegister)
130{
131        _Z2(atlTraceRegistrar, 2, _T("TreatAsClassIdentifier %ls, bRegister %d\n"), _PersistHelper::StringFromIdentifier(TreatAsClassIdentifier), bRegister);
132        // NOTE: Registration is much more sophisticated starting Vista operating system
133        const ULONG nOsVersion = GetOsVersion();
134        _Z4(atlTraceRegistrar, 4, _T("nOsVersion 0x%08x\n"), nOsVersion);
135        CProcessTokenPrivileges ProcessTokenPrivileges;
136        if(nOsVersion >= 0x060000) // Win Vista+
137                _ATLTRY
138                {
139                        ProcessTokenPrivileges.Adjust();
140                }
141                _ATLCATCHALL()
142                {
143                        _Z_EXCEPTION();
144                }
145        CLSID CurrentTreatAsClassIdentifier = CLSID_NULL;
146        const HRESULT nCoGetTreatAsClassResult = CoGetTreatAsClass(TreatAsClassIdentifier, &CurrentTreatAsClassIdentifier);
147        __C(nCoGetTreatAsClassResult);
148        _Z4(atlTraceRegistrar, 4, _T("bRegister %d, nCoGetTreatAsClassResult 0x%08x, CurrentTreatAsClassIdentifier %ls\n"), bRegister, nCoGetTreatAsClassResult, _PersistHelper::StringFromIdentifier(CurrentTreatAsClassIdentifier));
149        __D(!bRegister || nCoGetTreatAsClassResult != S_OK || CurrentTreatAsClassIdentifier == T::GetObjectCLSID(), E_UNNAMED);
150        CClassIdentifierRegKeySecurity ClassIdentifierRegKeySecurity(TreatAsClassIdentifier);
151        if(!bRegister && nCoGetTreatAsClassResult == S_OK)
152        {
153                if(nOsVersion >= 0x060000) // Win Vista+
154                        ClassIdentifierRegKeySecurity.Adjust();
155                __C(CoTreatAsClass(TreatAsClassIdentifier, CLSID_NULL));
156        }
157        _A(_pAtlModule);
158        UpdateRegistryFromResource<T>(bRegister);
159        if(bRegister)
160        {
161                if(nOsVersion >= 0x060000) // Win Vista+
162                        ClassIdentifierRegKeySecurity.Adjust();
163                #if _DEVELOPMENT
164                        const HRESULT nCoTreatAsClassResult = CoTreatAsClass(TreatAsClassIdentifier, T::GetObjectCLSID());
165                        _Z2(atlTraceRegistrar, SUCCEEDED(nCoTreatAsClassResult) ? 4 : 2, _T("nCoTreatAsClassResult 0x%08x\n"), nCoTreatAsClassResult);
166                        __C(nCoTreatAsClassResult);
167                        const HRESULT nCoGetTreatAsClassResult = CoGetTreatAsClass(TreatAsClassIdentifier, &CurrentTreatAsClassIdentifier);
168                        _Z4(atlTraceRegistrar, 4, _T("nCoGetTreatAsClassResult 0x%08x, CurrentTreatAsClassIdentifier %ls\n"), nCoGetTreatAsClassResult, _PersistHelper::StringFromIdentifier(CurrentTreatAsClassIdentifier));
169                        _A(CurrentTreatAsClassIdentifier == T::GetObjectCLSID());
170                #else
171                        __C(CoTreatAsClass(TreatAsClassIdentifier, T::GetObjectCLSID()));
172                #endif // _DEVELOPMENT
173        }
174}
175
176////////////////////////////////////////////////////////////
177// CBlackListAwareComCreatorT
178
179template <typename T, const CLSID* t_pClassIdentifier>
180class ATL_NO_VTABLE CTransparentCoClassT
181{
182public:
183// CTransparentCoClassT
184        static HINSTANCE CoLoadOriginalLibrary()
185        {
186                const HINSTANCE hModule = CoLoadLibrary(const_cast<LPOLESTR>((LPCOLESTR) CT2COLE(T::GetOriginalLibraryName())), TRUE);
187                __E(hModule);
188                return hModule;
189        }
190        static CComPtr<IUnknown> CoCreateOriginalInstance(HINSTANCE hModule, IUnknown* pControllingUnknown = NULL)
191        {
192                typedef HRESULT (WINAPI *DLLGETCLASSOBJECT)(REFCLSID, REFIID, VOID**);
193                DLLGETCLASSOBJECT DllGetClassObject = (DLLGETCLASSOBJECT) GetProcAddress(hModule, "DllGetClassObject");
194                __E(DllGetClassObject);
195                CComPtr<IClassFactory> pClassFactory;
196                __C(DllGetClassObject(*t_pClassIdentifier, __uuidof(IClassFactory), (VOID**) &pClassFactory));
197                _A(pClassFactory);
198                CComPtr<IUnknown> pUnknown;
199                __C(pClassFactory->CreateInstance(pControllingUnknown, __uuidof(IUnknown), (VOID**) &pUnknown));
200                return pUnknown;
201        }
202        static CComPtr<IUnknown> CoCreateOriginalInstance(IUnknown* pControllingUnknown = NULL)
203        {
204                CComPtr<IUnknown> pUnknown;
205                const HINSTANCE hModule = CoLoadOriginalLibrary();
206                _ATLTRY
207                {
208                        pUnknown = CoCreateOriginalInstance(hModule, pControllingUnknown);
209                }
210                _ATLCATCHALL()
211                {
212                        CoFreeLibrary(hModule);
213                        _ATLRETHROW;
214                }
215                CoFreeLibrary(hModule);
216                return pUnknown;
217        }
218};
219
220////////////////////////////////////////////////////////////
221// CBlackListAwareComCreatorT
222
223template <typename _ObjectClass, typename _Class, LPCTSTR* t_ppszName>
224class CBlackListAwareComCreatorT :
225        public CComCreator<_ObjectClass>
226{
227public:
228// CBlackListAwareComCreatorT
229        static HRESULT WINAPI CreateInstance(VOID* pvControllingUnknown, REFIID InterfaceIdentifier, VOID** ppvObject) throw()
230        {
231                _A(ppvObject);
232                *ppvObject = NULL;
233                #pragma region Check Black List
234                static INT g_nBlackListed = 0; // 0 Unknown, 1 No, 2 Yes
235                {
236                        _A(_pAtlModule);
237                        CComCritSecLock<CComCriticalSection> Lock(_pAtlModule->m_csStaticDataInitAndTypeInfo);
238                        if(!g_nBlackListed)
239                        {
240                                TCHAR pszPath[MAX_PATH] = { 0 };
241                                _W(GetModuleFileName(NULL, pszPath, DIM(pszPath)));
242                                LPTSTR pszFileName = FindFileName(pszPath);
243                                _A(pszFileName);
244                                RemoveExtension(pszFileName);
245                                const CString sBlackList = _RegKeyHelper::QueryStringValue(HKEY_LOCAL_MACHINE, REGISTRY_ROOT, AtlFormatString(_T("%s Black List"), *t_ppszName));
246                                CRoArrayT<CString> BlackListArray;
247                                _StringHelper::GetCommaSeparatedItems(sBlackList, BlackListArray);
248                                BOOL bFound = FALSE;
249                                for(SIZE_T nIndex = 0; nIndex < BlackListArray.GetCount(); nIndex++)
250                                {
251                                        CPath sFileName = (LPCTSTR) BlackListArray[nIndex];
252                                        sFileName.RemoveExtension();
253                                        if(_tcsicmp(sFileName, pszFileName) == 0)
254                                        {
255                                                _Z2(atlTraceCOM, 2, _T("Will instantiate original COM class, sFileName \"%s\"\n"), sFileName);
256                                                bFound = TRUE;
257                                                break;
258                                        }
259                                }
260                                g_nBlackListed = bFound ? 2 : 1;
261                        }
262                }
263                #pragma endregion
264                if(g_nBlackListed == 2)
265                        #pragma region CoCreateInstance Original Class
266                        _ATLTRY
267                        {
268                                _A(_pAtlModule);
269                                const LONG nLockResult = _pAtlModule->Lock();
270                                _ATLTRY
271                                {
272                                        CComPtr<IUnknown> pUnknown = _Class::CoCreateOriginalInstance((IUnknown*) pvControllingUnknown);
273                                        if(InterfaceIdentifier == __uuidof(IUnknown))
274                                                *ppvObject = pUnknown.Detach();
275                                        else
276                                                __C(pUnknown->QueryInterface(InterfaceIdentifier, ppvObject));
277                                }
278                                _ATLCATCHALL()
279                                {
280                                        _pAtlModule->Unlock();
281                                        _ATLRETHROW;
282                                }
283                                const LONG nUnlockResult = _pAtlModule->Unlock();
284                                _Z6(atlTraceGeneral, 6, _T("nLockResult %d, nUnlockResult %d\n"), nLockResult, nUnlockResult);
285                                return S_OK;
286                        }
287                        _ATLCATCH(Exception)
288                        {
289                                _A(FAILED(Exception));
290                                _C(Exception);
291                        }
292                        #pragma endregion
293                return __super::CreateInstance(pvControllingUnknown, InterfaceIdentifier, ppvObject);
294        }
295};
296
297////////////////////////////////////////////////////////////
298// CHookHostT
299
300template <typename T, typename IHook, LPCTSTR* t_ppszHookName>
301class CHookHostT
302{
303public:
304
305        ////////////////////////////////////////////////////////
306        // CHookArray
307
308        class CHookArray :
309                public CRoArrayT<CComPtr<IHook> >
310        {
311        public:
312        // CHookArray
313        };
314
315private:
316        mutable CRoCriticalSection m_HookCriticalSection;
317        BOOL m_bHookArrayInitialized;
318        CHookArray m_HookArray;
319
320        VOID InitializeHookArray()
321        {
322                _A(m_HookArray.IsEmpty());
323                CRoListT<CLSID> ClassIdentifierList;
324                static const HKEY g_phParentKeys[] = { HKEY_LOCAL_MACHINE, HKEY_CURRENT_USER };
325                static const LPCTSTR g_ppszKeyNameFormats[] = { _T("SOFTWARE\\Classes\\"), _T("Software\\Classes\\") };
326                for(SIZE_T nKeyIndex = 0; nKeyIndex < DIM(g_phParentKeys); nKeyIndex++)
327                {
328                        const CString sKeyName = AtlFormatString(_T("%sCLSID\\%ls\\Hooks\\%s"), g_ppszKeyNameFormats[nKeyIndex], _PersistHelper::StringFromIdentifier(T::GetObjectCLSID()), *t_ppszHookName);
329                        CRegKey Key;
330                        if(FAILED(HRESULT_FROM_WIN32(Key.Open(g_phParentKeys[nKeyIndex], sKeyName, KEY_READ))))
331                                continue;
332                        for(DWORD nIndex = 0; ; nIndex++)
333                        {
334                                DWORD nNameLength = 0;
335                                RegEnumKeyEx(Key, nIndex, NULL, &nNameLength, NULL, NULL, NULL, NULL);
336                                nNameLength = max(2 * nNameLength, 256);
337                                CTempBuffer<TCHAR, 4096> pszName(nNameLength);
338                                const HRESULT nRegEnumKeyResult = HRESULT_FROM_WIN32(RegEnumKeyEx(Key, nIndex, pszName, &nNameLength, NULL, NULL, NULL, NULL));
339                                if(FAILED(nRegEnumKeyResult))
340                                {
341                                        __D(nRegEnumKeyResult == HRESULT_FROM_WIN32(ERROR_NO_MORE_ITEMS), nRegEnumKeyResult);
342                                        break;
343                                }
344                                _ATLTRY
345                                {
346                                        const CLSID ClassIdentifier = _PersistHelper::ClassIdentifierFromString(CT2CW(pszName));
347                                        _Z4(atlTraceGeneral, 4, _T("ClassIdentifier %ls\n"), _PersistHelper::StringFromIdentifier(ClassIdentifier));
348                                        __D(ClassIdentifier != CLSID_NULL, E_UNNAMED);
349                                        if(ClassIdentifierList.Find(ClassIdentifier))
350                                                continue;
351                                        _W(ClassIdentifierList.AddTail(ClassIdentifier));
352                                        CComPtr<IHook> pHook;
353                                        __C(pHook.CoCreateInstance(ClassIdentifier));
354                                        _W(m_HookArray.Add(pHook) >= 0);
355                                }
356                                _ATLCATCHALL()
357                                {
358                                        _Z_EXCEPTION();
359                                }
360                        }
361                }
362        }
363
364public:
365// CHookHostT
366        CHookHostT() throw() :
367                m_bHookArrayInitialized(FALSE)
368        {
369        }
370        SIZE_T GetHookArray(CHookArray& HookArray)
371        {
372                _A(HookArray.IsEmpty());
373                CRoCriticalSectionLock HookLock(m_HookCriticalSection);
374                if(!m_bHookArrayInitialized)
375                        _ATLTRY
376                        {
377                                m_bHookArrayInitialized = TRUE;
378                                InitializeHookArray();
379                        }
380                        _ATLCATCHALL()
381                        {
382                                _Z_EXCEPTION();
383                        }
384                HookArray.Append(m_HookArray);
385                return HookArray.GetCount();
386        }
387};
388
389#define HOOK_PROLOG(Base) \
390        _ATLTRY \
391        { \
392                Base::CHookArray HookArray; \
393                if(Base::GetHookArray(HookArray)) \
394                { \
395                        T* pT = static_cast<T*>(this); \
396                        for(SIZE_T nIndex = 0; nIndex < HookArray.GetCount(); nIndex++) \
397                        { \
398                                BOOL bDefault = TRUE; \
399                                const HRESULT nResult = HookArray[nIndex]->
400                                               
401#define HOOK_EPILOG() \
402                                if(!bDefault) \
403                                        return nResult; \
404                                _A(SUCCEEDED(nResult)); \
405                        } \
406                } \
407        } \
408        _ATLCATCHALL() \
409        { \
410                _Z_EXCEPTION(); \
411        }
412
Note: See TracBrowser for help on using the repository browser.